├── .github └── workflows │ └── main.yml ├── .gitignore ├── .pre-commit-config.yaml ├── COPYING ├── LICENSE.torchrec ├── README.md ├── common ├── __init__.py ├── batch.py ├── checkpointing │ ├── __init__.py │ └── snapshot.py ├── device.py ├── filesystem │ ├── __init__.py │ ├── test_infer_fs.py │ └── util.py ├── log_weights.py ├── modules │ └── embedding │ │ ├── config.py │ │ └── embedding.py ├── run_training.py ├── test_device.py ├── testing_utils.py ├── utils.py └── wandb.py ├── core ├── __init__.py ├── config │ ├── __init__.py │ ├── base_config.py │ ├── base_config_test.py │ ├── config_load.py │ ├── test_config_load.py │ └── training.py ├── custom_training_loop.py ├── debug_training_loop.py ├── loss_type.py ├── losses.py ├── metric_mixin.py ├── metrics.py ├── test_metrics.py ├── test_train_pipeline.py └── train_pipeline.py ├── images ├── init_venv.sh └── requirements.txt ├── machines ├── environment.py ├── get_env.py ├── is_venv.py └── list_ops.py ├── metrics ├── __init__.py ├── aggregation.py ├── auroc.py └── rce.py ├── ml_logging ├── __init__.py ├── absl_logging.py ├── test_torch_logging.py └── torch_logging.py ├── model.py ├── optimizers ├── __init__.py ├── config.py └── optimizer.py ├── projects ├── __init__.py ├── home │ └── recap │ │ ├── FEATURES.md │ │ ├── README.md │ │ ├── __init__.py │ │ ├── config.py │ │ ├── config │ │ ├── home_recap_2022 │ │ │ └── segdense.json │ │ └── local_prod.yaml │ │ ├── data │ │ ├── __init__.py │ │ ├── config.py │ │ ├── dataset.py │ │ ├── generate_random_data.py │ │ ├── preprocessors.py │ │ ├── tfe_parsing.py │ │ └── util.py │ │ ├── embedding │ │ └── config.py │ │ ├── main.py │ │ ├── model │ │ ├── __init__.py │ │ ├── config.py │ │ ├── entrypoint.py │ │ ├── feature_transform.py │ │ ├── mask_net.py │ │ ├── mlp.py │ │ ├── model_and_loss.py │ │ └── numeric_calibration.py │ │ ├── optimizer │ │ ├── __init__.py │ │ ├── config.py │ │ └── optimizer.py │ │ └── script │ │ ├── create_random_data.sh │ │ └── run_local.sh └── twhin │ ├── README.md │ ├── config.py │ ├── config │ └── local.yaml │ ├── data │ ├── config.py │ ├── data.py │ ├── edges.py │ ├── test_data.py │ └── test_edges.py │ ├── machines.yaml │ ├── metrics.py │ ├── models │ ├── config.py │ ├── models.py │ └── test_models.py │ ├── optimizer.py │ ├── run.py │ ├── scripts │ ├── docker_run.sh │ └── run_in_docker.sh │ └── test_optimizer.py ├── pyproject.toml ├── reader ├── __init__.py ├── dataset.py ├── dds.py ├── test_dataset.py ├── test_utils.py └── utils.py └── tools └── pq.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.10"] 11 | 12 | steps: 13 | - uses: actions/checkout@v3 14 | # - uses: pre-commit/action@v3.0.0 15 | # name: Run pre-commit checks (pylint/yapf/isort) 16 | # env: 17 | # SKIP: insert-license 18 | # with: 19 | # extra_args: --hook-stage push --all-files 20 | - uses: actions/setup-python@v4 21 | with: 22 | python-version: "3.10" 23 | cache: "pip" # caching pip dependencies 24 | - name: install packages 25 | run: | 26 | /usr/bin/python -m pip install --upgrade pip 27 | pip install --no-deps -r images/requirements.txt 28 | # - name: ssh access 29 | # uses: lhotari/action-upterm@v1 30 | # with: 31 | # limit-access-to-actor: true 32 | # limit-access-to-users: arashd 33 | - name: run tests 34 | run: | 35 | # Environment variables are reset in between steps. 
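          # Symlinking the checkout into /tmp/github_testing mirrors what images/init_venv.sh
          # does locally, so that `import tml.*` resolves during pytest.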
36 | mkdir /tmp/github_testing 37 | ln -s $GITHUB_WORKSPACE /tmp/github_testing/tml 38 | export PYTHONPATH="/tmp/github_testing:$PYTHONPATH" 39 | pytest -vv 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Mac 2 | .DS_Store 3 | 4 | # Vim 5 | *.py.swp 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # Installer logs 29 | pip-log.txt 30 | pip-delete-this-directory.txt 31 | 32 | # Unit test / coverage reports 33 | .hypothesis 34 | 35 | venv 36 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pausan/cblack 3 | rev: release-22.3.0 4 | hooks: 5 | - id: cblack 6 | name: cblack 7 | description: "Black: The uncompromising Python code formatter - 2 space indent fork" 8 | entry: cblack . -l 100 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v2.3.0 11 | hooks: 12 | - id: trailing-whitespace 13 | - id: end-of-file-fixer 14 | - id: check-yaml 15 | - id: check-added-large-files 16 | - id: check-merge-conflict 17 | -------------------------------------------------------------------------------- /LICENSE.torchrec: -------------------------------------------------------------------------------- 1 | A few files here (where it is specifically noted in comments) are based on code from torchrec but 2 | adapted for our use. Torchrec license is below: 3 | 4 | 5 | BSD 3-Clause License 6 | 7 | Copyright (c) Meta Platforms, Inc. and affiliates. 8 | All rights reserved. 9 | 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions are met: 12 | 13 | * Redistributions of source code must retain the above copyright notice, this 14 | list of conditions and the following disclaimer. 15 | 16 | * Redistributions in binary form must reproduce the above copyright notice, 17 | this list of conditions and the following disclaimer in the documentation 18 | and/or other materials provided with the distribution. 19 | 20 | * Neither the name of the copyright holder nor the names of its 21 | contributors may be used to endorse or promote products derived from 22 | this software without specific prior written permission. 23 | 24 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
34 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This project open sources some of the ML models used at Twitter.
2 | 
3 | Currently these are:
4 | 
5 | 1. The "For You" Heavy Ranker (projects/home/recap).
6 | 
7 | 2. TwHIN embeddings (projects/twhin), https://arxiv.org/abs/2202.05387
8 | 
9 | 
10 | This project can be run inside a Python virtualenv. We have only tried this on Linux machines, and because we use torchrec it works best with an Nvidia GPU. To set up, run
11 | 
12 | `./images/init_venv.sh` (Linux only).
13 | 
14 | The README of each project contains instructions on how to run it.
15 | 
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twitter/the-algorithm-ml/b85210863f7a94efded0ef5c5ccf4ff42767876c/common/__init__.py
--------------------------------------------------------------------------------
/common/batch.py:
--------------------------------------------------------------------------------
1 | """Extension of torchrec.dataset.utils.Batch to cover any dataset.
2 | """
3 | # flake8: noqa
4 | from __future__ import annotations
5 | from typing import Dict
6 | import abc
7 | from dataclasses import dataclass
8 | import dataclasses
9 | 
10 | import torch
11 | from torchrec.streamable import Pipelineable
12 | 
13 | 
14 | class BatchBase(Pipelineable, abc.ABC):
15 |   @abc.abstractmethod
16 |   def as_dict(self) -> Dict:
17 |     raise NotImplementedError
18 | 
19 |   def to(self, device: torch.device, non_blocking: bool = False):
20 |     args = {}
21 |     for feature_name, feature_value in self.as_dict().items():
22 |       args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
23 |     return self.__class__(**args)
24 | 
25 |   def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
26 |     for feature_value in self.as_dict().values():
27 |       feature_value.record_stream(stream)
28 | 
29 |   def pin_memory(self):
30 |     args = {}
31 |     for feature_name, feature_value in self.as_dict().items():
32 |       args[feature_name] = feature_value.pin_memory()
33 |     return self.__class__(**args)
34 | 
35 |   def __repr__(self) -> str:
36 |     def obj2str(v):
37 |       return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
38 | 
39 |     return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
40 | 
41 |   @property
42 |   def batch_size(self) -> int:
43 |     for tensor in self.as_dict().values():
44 |       if tensor is None:
45 |         continue
46 |       if not isinstance(tensor, torch.Tensor):
47 |         continue
48 |       return tensor.shape[0]
49 |     raise Exception("Could not determine batch size from tensors.")
50 | 
51 | 
52 | @dataclass
53 | class DataclassBatch(BatchBase):
54 |   @classmethod
55 |   def feature_names(cls):
56 |     return list(cls.__dataclass_fields__.keys())
57 | 
58 |   def as_dict(self):
59 |     return {
60 |       feature_name: getattr(self, feature_name)
61 |       for feature_name in self.feature_names()
62 |       if hasattr(self, feature_name)
63 |     }
64 | 
65 |   @staticmethod
66 |   def from_schema(name: str, schema):
67 |     """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
68 |     return dataclasses.make_dataclass(
69 |       cls_name=name,
70 |       fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
71 |       bases=(DataclassBatch,),
72 |     )
73 | 
74 | 
@staticmethod 75 | def from_fields(name: str, fields: dict): 76 | return dataclasses.make_dataclass( 77 | cls_name=name, 78 | fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()], 79 | bases=(DataclassBatch,), 80 | ) 81 | 82 | 83 | class DictionaryBatch(BatchBase, dict): 84 | def as_dict(self) -> Dict: 85 | return self 86 | -------------------------------------------------------------------------------- /common/checkpointing/__init__.py: -------------------------------------------------------------------------------- 1 | from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot 2 | -------------------------------------------------------------------------------- /common/device.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | def maybe_setup_tensorflow(): 8 | try: 9 | import tensorflow as tf 10 | except ImportError: 11 | pass 12 | else: 13 | tf.config.set_visible_devices([], "GPU") # disable tf gpu 14 | 15 | 16 | def setup_and_get_device(tf_ok: bool = True) -> torch.device: 17 | if tf_ok: 18 | maybe_setup_tensorflow() 19 | 20 | device = torch.device("cpu") 21 | backend = "gloo" 22 | if torch.cuda.is_available(): 23 | rank = os.environ["LOCAL_RANK"] 24 | device = torch.device(f"cuda:{rank}") 25 | backend = "nccl" 26 | torch.cuda.set_device(device) 27 | if not torch.distributed.is_initialized(): 28 | dist.init_process_group(backend) 29 | 30 | return device 31 | -------------------------------------------------------------------------------- /common/filesystem/__init__.py: -------------------------------------------------------------------------------- 1 | from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs 2 | -------------------------------------------------------------------------------- /common/filesystem/test_infer_fs.py: -------------------------------------------------------------------------------- 1 | """Minimal test for infer_fs. 2 | 3 | Mostly a test that it returns an object 4 | """ 5 | from tml.common.filesystem import infer_fs 6 | 7 | 8 | def test_infer_fs(): 9 | local_path = "/tmp/local_path" 10 | gcs_path = "gs://somebucket/somepath" 11 | 12 | local_fs = infer_fs(local_path) 13 | gcs_fs = infer_fs(gcs_path) 14 | 15 | # This should return two different objects 16 | assert local_fs != gcs_fs 17 | -------------------------------------------------------------------------------- /common/filesystem/util.py: -------------------------------------------------------------------------------- 1 | """Utilities for interacting with the file systems.""" 2 | from fsspec.implementations.local import LocalFileSystem 3 | import gcsfs 4 | 5 | 6 | GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1) 7 | LOCAL_FS = LocalFileSystem() 8 | 9 | 10 | def infer_fs(path: str): 11 | if path.startswith("gs://"): 12 | return GCS_FS 13 | elif path.startswith("hdfs://"): 14 | # We can probably use pyarrow HDFS to support this. 
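    # A possible sketch (untested): return fsspec.filesystem("hdfs"), which fsspec
    # backs with pyarrow's HadoopFileSystem when pyarrow is installed.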
15 |     raise NotImplementedError("HDFS not yet supported")
16 |   else:
17 |     return LOCAL_FS
18 | 
19 | 
20 | def is_local_fs(fs):
21 |   return fs == LOCAL_FS
22 | 
23 | 
24 | def is_gcs_fs(fs):
25 |   return fs == GCS_FS
26 | 
--------------------------------------------------------------------------------
/common/log_weights.py:
--------------------------------------------------------------------------------
1 | """For logging model weights."""
2 | import itertools
3 | from typing import Callable, Dict, List, Optional, Union
4 | 
5 | from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
6 | import torch
7 | import torch.distributed as dist
8 | from torchrec.distributed.model_parallel import DistributedModelParallel
9 | 
10 | 
11 | def weights_to_log(
12 |   model: torch.nn.Module,
13 |   how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
14 | ):
15 |   """Creates a dict of reduced weights to log, to give a sense of training progress.
16 | 
17 |   Args:
18 |     model: model to traverse.
19 |     how_to_log: if a function, it is applied to every parameter; if a dict,
20 |       it is applied to, and logged for, only the specified parameters.
21 | 
22 |   """
23 |   if not how_to_log:
24 |     return
25 | 
26 |   to_log = dict()
27 |   named_parameters = model.named_parameters()
28 |   logging.info(f"Using DMP: {isinstance(model, DistributedModelParallel)}")
29 |   if isinstance(model, DistributedModelParallel):
30 |     named_parameters = itertools.chain(
31 |       named_parameters, model._dmp_wrapped_module.named_parameters()
32 |     )
33 |     logging.info(
34 |       f"Using dmp parameters: {list(name for name, _ in model._dmp_wrapped_module.named_parameters())}"
35 |     )
36 |   for param_name, params in named_parameters:
37 |     if callable(how_to_log):
38 |       how = how_to_log
39 |     else:
40 |       how = how_to_log.get(param_name)  # type: ignore[assignment]
41 |     if not how:
42 |       continue  # type: ignore
43 |     to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
44 |   return to_log
45 | 
46 | 
47 | def log_ebc_norms(
48 |   model_state_dict,
49 |   ebc_keys: List[str],
50 |   sample_size: int = 4_000_000,
51 | ) -> Dict[str, torch.Tensor]:
52 |   """Logs the norms of the embedding tables as specified by ebc_keys.
53 |   As of now, logs the average norm per rank.
54 | 
55 |   Args:
56 |     model_state_dict: model.state_dict()
57 |     ebc_keys: list of embedding keys from state_dict to log. Each must be a full name,
58 |       e.g. model.embeddings.ebc.embedding_bags.meta__user_id.weight
59 |     sample_size: Limits the number of rows per rank used to compute the average, to avoid OOM.
60 |   """
61 |   norm_logs = dict()
62 |   for emb_key in ebc_keys:
63 |     norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))
64 |     if emb_key in model_state_dict:
65 |       emb_weight = model_state_dict[emb_key]
66 |       try:
67 |         emb_weight_tensor = emb_weight.local_tensor()
68 |       except AttributeError as e:
69 |         logging.info(e)
70 |         emb_weight_tensor = emb_weight
71 |       logging.info("Running Tensor.detach()")
72 |       emb_weight_tensor = emb_weight_tensor.detach()
73 |       sample_mask = torch.randperm(emb_weight_tensor.shape[0])[
74 |         : min(sample_size, emb_weight_tensor.shape[0])
75 |       ]
76 |       # WARNING: the .cpu() transfer executes a malloc that may be the cause of memory leaks.
77 |       # Change sample_size if you observe frequent OOM errors, or remove weight logging.
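      # Per-row norms of the sampled rows are computed on the host (note the .cpu() before .norm).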
78 | norms = emb_weight_tensor[sample_mask].cpu().norm(dim=1).to(torch.float32) 79 | logging.info(f"Norm shape before reduction: {norms.shape}", rank=-1) 80 | norms = norms.mean().to(torch.device(f"cuda:{dist.get_rank()}")) 81 | 82 | all_norms = [ 83 | torch.zeros(1, dtype=norms.dtype).to(norms.device) for _ in range(dist.get_world_size()) 84 | ] 85 | dist.all_gather(all_norms, norms) 86 | for idx, norm in enumerate(all_norms): 87 | if norm != -1.0: 88 | norm_logs[f"{emb_key}-norm-{idx}"] = norm 89 | logging.info(f"Norm Logs are {norm_logs}") 90 | return norm_logs 91 | -------------------------------------------------------------------------------- /common/modules/embedding/config.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from enum import Enum 3 | 4 | import tml.core.config as base_config 5 | from tml.optimizers.config import OptimizerConfig 6 | 7 | import pydantic 8 | 9 | 10 | class DataType(str, Enum): 11 | FP32 = "fp32" 12 | FP16 = "fp16" 13 | 14 | 15 | class EmbeddingSnapshot(base_config.BaseConfig): 16 | """Configuration for Embedding snapshot""" 17 | 18 | emb_name: str = pydantic.Field( 19 | ..., description="Name of the embedding table from the loaded snapshot" 20 | ) 21 | embedding_snapshot_uri: str = pydantic.Field( 22 | ..., description="Path to torchsnapshot of the embedding" 23 | ) 24 | 25 | 26 | class EmbeddingBagConfig(base_config.BaseConfig): 27 | """Configuration for EmbeddingBag.""" 28 | 29 | name: str = pydantic.Field(..., description="name of embedding bag") 30 | num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary") 31 | embedding_dim: int = pydantic.Field(..., description="size of each embedding vector") 32 | pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties") 33 | vocab: str = pydantic.Field( 34 | None, description="Directory to parquet files of mapping from entity ID to table index." 35 | ) 36 | # make sure to use an optimizer that matches: 37 | # https://github.com/pytorch/FBGEMM/blob/4c58137529d221390575e47e88d3c05ce65b66fd/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py#L15 38 | optimizer: OptimizerConfig 39 | data_type: DataType 40 | 41 | 42 | class LargeEmbeddingsConfig(base_config.BaseConfig): 43 | """Configuration for EmbeddingBagCollection. 44 | 45 | The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection. 
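
  A minimal sketch of one table in yaml form (hypothetical values; the optimizer
  block follows OptimizerConfig):

    tables:
      - name: user_id
        num_embeddings: 1000000
        embedding_dim: 128
        data_type: fp32
        optimizer:
          ...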
46 | """ 47 | 48 | tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables") 49 | tables_to_log: List[str] = pydantic.Field( 50 | None, description="list of embedding table names that we want to log during training" 51 | ) 52 | 53 | 54 | class Mode(str, Enum): 55 | """Job modes.""" 56 | 57 | TRAIN = "train" 58 | EVALUATE = "evaluate" 59 | INFERENCE = "inference" 60 | -------------------------------------------------------------------------------- /common/modules/embedding/embedding.py: -------------------------------------------------------------------------------- 1 | from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType 2 | from tml.ml_logging.torch_logging import logging 3 | 4 | import torch 5 | from torch import nn 6 | import torchrec 7 | from torchrec.modules import embedding_configs 8 | from torchrec import EmbeddingBagConfig, EmbeddingBagCollection 9 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor 10 | import numpy as np 11 | 12 | 13 | class LargeEmbeddings(nn.Module): 14 | def __init__( 15 | self, 16 | large_embeddings_config: LargeEmbeddingsConfig, 17 | ): 18 | super().__init__() 19 | 20 | tables = [] 21 | for table in large_embeddings_config.tables: 22 | data_type = ( 23 | embedding_configs.DataType.FP32 24 | if (table.data_type == DataType.FP32) 25 | else embedding_configs.DataType.FP16 26 | ) 27 | 28 | tables.append( 29 | EmbeddingBagConfig( 30 | embedding_dim=table.embedding_dim, 31 | feature_names=[table.name], # restricted to 1 feature per table for now 32 | name=table.name, 33 | num_embeddings=table.num_embeddings, 34 | pooling=torchrec.PoolingType.SUM, 35 | data_type=data_type, 36 | ) 37 | ) 38 | 39 | self.ebc = EmbeddingBagCollection( 40 | device="meta", 41 | tables=tables, 42 | ) 43 | 44 | logging.info("********************** EBC named params are **********") 45 | logging.info(list(self.ebc.named_parameters())) 46 | 47 | # This hook is used to perform post-processing surgery 48 | # on large_embedding models to prep them for serving 49 | self.surgery_cut_point = torch.nn.Identity() 50 | 51 | def forward( 52 | self, 53 | sparse_features: KeyedJaggedTensor, 54 | ) -> KeyedTensor: 55 | pooled_embs = self.ebc(sparse_features) 56 | 57 | # a KeyedTensor 58 | return self.surgery_cut_point(pooled_embs) 59 | -------------------------------------------------------------------------------- /common/run_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | from typing import Optional 5 | 6 | from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined] 7 | from twitter.ml.tensorflow.experimental.distributed import utils 8 | 9 | import torch 10 | import torch.distributed.run 11 | 12 | 13 | def is_distributed_worker(): 14 | world_size = os.environ.get("WORLD_SIZE", None) 15 | rank = os.environ.get("RANK", None) 16 | return world_size is not None and rank is not None 17 | 18 | 19 | def maybe_run_training( 20 | train_fn, 21 | module_name, 22 | nproc_per_node: Optional[int] = None, 23 | num_nodes: Optional[int] = None, 24 | set_python_path_in_subprocess: bool = False, 25 | is_chief: Optional[bool] = False, 26 | **training_kwargs, 27 | ): 28 | """Wrapper function for single node, multi-GPU Pytorch training. 29 | 30 | If the necessary distributed Pytorch environment variables 31 | (WORLD_SIZE, RANK) have been set, then this function executes 32 | `train_fn(**training_kwargs)`. 
33 | 
34 |   Otherwise, this function calls torchrun and points at the calling module
35 |   `module_name`. After this call, the necessary environment variables are set
36 |   and training will commence.
37 | 
38 |   Args:
39 |     train_fn: The function that is responsible for training.
40 |     module_name: The name of the module that this function was called from;
41 |       used to indicate torchrun entrypoint.
42 |     nproc_per_node: Number of worker processes per node.
43 |     num_nodes: Number of nodes, otherwise inferred from environment.
44 |     is_chief: Whether this process is running on the chief.
45 |     set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH.
46 |   """
47 | 
48 |   machines = utils.machine_from_env()
49 |   if num_nodes is None:
50 |     num_nodes = 1
51 |     if machines.num_workers:
52 |       num_nodes += machines.num_workers
53 | 
54 |   if is_distributed_worker():
55 |     # world_size, rank, etc are set; assuming any other env vars are set (checks to come)
56 |     # start the actual training!
57 |     train_fn(**training_kwargs)
58 |   else:
59 |     if nproc_per_node is None:
60 |       if torch.cuda.is_available():
61 |         nproc_per_node = torch.cuda.device_count()
62 |       else:
63 |         nproc_per_node = machines.chief.num_accelerators
64 | 
65 |     # Rejoin all arguments to send back through torchrun.
66 |     # This is a temporary measure; we will replace the subprocess call
67 |     # with torchrun API calls.
68 |     args = list(f"--{key}={val}" for key, val in training_kwargs.items())
69 | 
70 |     cmd = [
71 |       "--nnodes",
72 |       str(num_nodes),
73 |     ]
74 |     if nproc_per_node:
75 |       cmd.extend(["--nproc_per_node", str(nproc_per_node)])
76 |     if num_nodes > 1:
77 |       cluster_resolver = utils.cluster_resolver()
78 |       backend_address = cluster_resolver.cluster_spec().task_address("chief", 0)
79 |       cmd.extend(
80 |         [
81 |           "--rdzv_backend",
82 |           "c10d",
83 |           "--rdzv_id",
84 |           backend_address,
85 |         ]
86 |       )
87 |       # Set localhost on chief because of https://github.com/pytorch/pytorch/issues/79388
88 |       if is_chief:
89 |         cmd.extend(["--rdzv_endpoint", "localhost:2222"])
90 |       else:
91 |         cmd.extend(["--rdzv_endpoint", backend_address])
92 |     else:
93 |       cmd.append("--standalone")
94 | 
95 |     cmd.extend(
96 |       [
97 |         str(module_name),
98 |         *args,
99 |       ]
100 |     )
101 |     logging.info(f"""Distributed running with cmd: '{" ".join(cmd)}'""")
102 | 
103 |     # Call torchrun on this module; will spawn new processes and re-run this
104 |     # function, eventually calling "train_fn". The following line sets the PYTHONPATH to accommodate
105 |     # bazel stubbing for the main binary.
106 |     if set_python_path_in_subprocess:
107 |       subprocess.run(["torchrun"] + cmd, env={**os.environ, "PYTHONPATH": ":".join(sys.path)})
108 |     else:
109 |       torch.distributed.run.main(cmd)
110 | 
--------------------------------------------------------------------------------
/common/test_device.py:
--------------------------------------------------------------------------------
1 | """Minimal test for device.
2 | 
3 | Mostly a test that this can be imported properly even though moved.
4 | """ 5 | from unittest.mock import patch 6 | 7 | import tml.common.device as device_utils 8 | 9 | 10 | def test_device(): 11 | with patch("tml.common.device.dist.init_process_group"): 12 | device = device_utils.setup_and_get_device(tf_ok=False) 13 | assert device.type == "cpu" 14 | -------------------------------------------------------------------------------- /common/testing_utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import datetime 3 | import os 4 | from unittest.mock import patch 5 | 6 | import torch.distributed as dist 7 | from tml.ml_logging.torch_logging import logging 8 | 9 | 10 | MOCK_ENV = { 11 | "LOCAL_RANK": "0", 12 | "WORLD_SIZE": "1", 13 | "LOCAL_WORLD_SIZE": "1", 14 | "MASTER_ADDR": "localhost", 15 | "MASTER_PORT": "29501", 16 | "RANK": "0", 17 | } 18 | 19 | 20 | @contextmanager 21 | def mock_pg(): 22 | with patch.dict(os.environ, MOCK_ENV): 23 | try: 24 | dist.init_process_group( 25 | backend="gloo", 26 | timeout=datetime.timedelta(1), 27 | ) 28 | yield 29 | except: 30 | dist.destroy_process_group() 31 | raise 32 | finally: 33 | dist.destroy_process_group() 34 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import getpass 3 | import os 4 | import string 5 | from typing import Tuple, Type, TypeVar 6 | 7 | from tml.core.config import base_config 8 | 9 | import fsspec 10 | 11 | C = TypeVar("C", bound=base_config.BaseConfig) 12 | 13 | 14 | def _read_file(f): 15 | with fsspec.open(f) as f: 16 | return f.read() 17 | 18 | 19 | def setup_configuration( 20 | config_type: Type[C], 21 | yaml_path: str, 22 | substitute_env_variable: bool = False, 23 | ) -> Tuple[C, str]: 24 | """Resolves a config at a yaml path. 25 | 26 | Args: 27 | config_type: Pydantic config class to load. 28 | yaml_path: yaml path of the config file. 29 | substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their 30 | environment variable value whenever possible. If an environment variable doesn't exist, 31 | the string is left unchanged. 32 | 33 | Returns: 34 | The pydantic config object. 
35 | """ 36 | 37 | def _substitute(s): 38 | if substitute_env_variable: 39 | return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser()) 40 | return s 41 | 42 | assert config_type is not None, "can't use all_config without config_type" 43 | content = _substitute(yaml.safe_load(_read_file(yaml_path))) 44 | return config_type.parse_obj(content) 45 | -------------------------------------------------------------------------------- /common/wandb.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import tml.core.config as base_config 4 | 5 | import pydantic 6 | 7 | 8 | class WandbConfig(base_config.BaseConfig): 9 | host: str = pydantic.Field( 10 | "https://https--wandb--prod--wandb.service.qus1.twitter.biz/", 11 | description="Host of Weights and Biases instance, passed to login.", 12 | ) 13 | key_path: str = pydantic.Field(description="Path to key file.") 14 | 15 | name: str = pydantic.Field(None, description="Name of the experiment, passed to init.") 16 | entity: str = pydantic.Field(None, description="Name of user/service account, passed to init.") 17 | project: str = pydantic.Field(None, description="Name of wandb project, passed to init.") 18 | tags: List[str] = pydantic.Field([], description="List of tags, passed to init.") 19 | notes: str = pydantic.Field(None, description="Notes, passed to init.") 20 | metadata: Dict[str, Any] = pydantic.Field(None, description="Additional metadata to log.") 21 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twitter/the-algorithm-ml/b85210863f7a94efded0ef5c5ccf4ff42767876c/core/__init__.py -------------------------------------------------------------------------------- /core/config/__init__.py: -------------------------------------------------------------------------------- 1 | from tml.core.config.base_config import BaseConfig 2 | from tml.core.config.config_load import load_config_from_yaml 3 | 4 | # Make mypy happy by explicitly rexporting the symbols intended for end user use. 5 | __all__ = ["BaseConfig", "load_config_from_yaml"] 6 | -------------------------------------------------------------------------------- /core/config/base_config.py: -------------------------------------------------------------------------------- 1 | """Base class for all config (forbids extra fields).""" 2 | 3 | import collections 4 | import functools 5 | import yaml 6 | 7 | import pydantic 8 | 9 | 10 | class BaseConfig(pydantic.BaseModel): 11 | """Base class for all derived config classes. 12 | 13 | This class provides some convenient functionality: 14 | - Disallows extra fields when constructing an object. User error 15 | should be reduced by exact arguments. 16 | - "one_of" fields. A subclass can group optional fields and enforce 17 | that only one of the fields be set. 
For example:
18 | 
19 | ```
20 | class ExampleConfig(BaseConfig):
21 |   x: int = Field(None, one_of="group_1")
22 |   y: int = Field(None, one_of="group_1")
23 | 
24 | ExampleConfig(x=1)  # ok
25 | ExampleConfig(y=1)  # ok
26 | ExampleConfig(x=1, y=1)  # throws error
27 | ```
28 | """
29 | 
30 |   class Config:
31 |     """Forbids extras."""
32 | 
33 |     extra = pydantic.Extra.forbid  # noqa
34 | 
35 |   @classmethod
36 |   @functools.lru_cache()
37 |   def _field_data_map(cls, field_data_name):
38 |     """Create a map of fields that carry the given field data."""
39 |     schema = cls.schema()
40 |     one_of = collections.defaultdict(list)
41 |     for field, fdata in schema["properties"].items():
42 |       if field_data_name in fdata:
43 |         one_of[fdata[field_data_name]].append(field)
44 |     return one_of
45 | 
46 |   @pydantic.root_validator
47 |   def _one_of_check(cls, values):
48 |     """Validate that all 'one_of' fields appear exactly once."""
49 |     one_of_map = cls._field_data_map("one_of")
50 |     for one_of, field_names in one_of_map.items():
51 |       if sum([values.get(n, None) is not None for n in field_names]) != 1:
52 |         raise ValueError(f"Exactly one of {','.join(field_names)} required.")
53 |     return values
54 | 
55 |   @pydantic.root_validator
56 |   def _at_most_one_of_check(cls, values):
57 |     """Validate that all 'at_most_one_of' fields appear at most once."""
58 |     at_most_one_of_map = cls._field_data_map("at_most_one_of")
59 |     for one_of, field_names in at_most_one_of_map.items():
60 |       if sum([values.get(n, None) is not None for n in field_names]) > 1:
61 |         raise ValueError(f"At most one of {','.join(field_names)} can be set.")
62 |     return values
63 | 
64 |   def pretty_print(self) -> str:
65 |     """Return a human-legible (yaml) representation of the config, useful for logging."""
66 |     return yaml.dump(self.dict())
67 | 
--------------------------------------------------------------------------------
/core/config/base_config_test.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | 
3 | from tml.core.config import BaseConfig
4 | 
5 | import pydantic
6 | 
7 | 
8 | class BaseConfigTest(TestCase):
9 |   def test_extra_forbidden(self):
10 |     class Config(BaseConfig):
11 |       x: int
12 | 
13 |     Config(x=1)
14 |     with self.assertRaises(pydantic.ValidationError):
15 |       Config(x=1, y=2)
16 | 
17 |   def test_one_of(self):
18 |     class Config(BaseConfig):
19 |       x: int = pydantic.Field(None, one_of="f")
20 |       y: int = pydantic.Field(None, one_of="f")
21 | 
22 |     with self.assertRaises(pydantic.ValidationError):
23 |       Config()
24 |     Config(x=1)
25 |     Config(y=1)
26 |     with self.assertRaises(pydantic.ValidationError):
27 |       Config(x=1, y=3)
28 | 
29 |   def test_at_most_one_of(self):
30 |     class Config(BaseConfig):
31 |       x: int = pydantic.Field(None, at_most_one_of="f")
32 |       y: str = pydantic.Field(None, at_most_one_of="f")
33 | 
34 |     Config()
35 |     Config(x=1)
36 |     Config(y="a")
37 |     with self.assertRaises(pydantic.ValidationError):
38 |       Config(x=1, y="a")
39 | 
--------------------------------------------------------------------------------
/core/config/config_load.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import string
3 | import getpass
4 | import os
5 | from typing import Type
6 | 
7 | from tml.core.config.base_config import BaseConfig
8 | 
9 | 
10 | def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
11 |   """Recommended method to load a config file (a yaml file) and parse it.
12 | 
13 |   Because we have a shared filesystem, the recommended route to running jobs is to put modified
14 |   config files with the desired parameters somewhere on the filesystem and run jobs pointing to them.
15 |   """
16 | 
17 |   def _substitute(s):
18 |     return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
19 | 
20 |   with open(yaml_path, "r") as f:
21 |     raw_contents = f.read()
22 |     obj = yaml.safe_load(_substitute(raw_contents))
23 | 
24 |   return config_type.parse_obj(obj)
25 | 
--------------------------------------------------------------------------------
/core/config/test_config_load.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | 
3 | from tml.core.config import BaseConfig, load_config_from_yaml
4 | 
5 | import pydantic
6 | import getpass
7 | 
8 | 
9 | class _PointlessConfig(BaseConfig):
10 |   a: int
11 |   user: str
12 | 
13 | 
14 | def test_load_config_from_yaml(tmp_path):
15 |   yaml_path = tmp_path.joinpath("test.yaml").as_posix()
16 |   with open(yaml_path, "w") as yaml_file:
17 |     yaml_file.write("""a: 3\nuser: ${USER}\n""")
18 | 
19 |   pointless_config = load_config_from_yaml(_PointlessConfig, yaml_path)
20 | 
21 |   assert pointless_config.a == 3
22 |   assert pointless_config.user == getpass.getuser()
23 | 
--------------------------------------------------------------------------------
/core/config/training.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional
2 | 
3 | from tml.common.wandb import WandbConfig
4 | from tml.core.config import base_config
5 | from tml.projects.twhin.data.config import TwhinDataConfig
6 | from tml.projects.twhin.models.config import TwhinModelConfig
7 | 
8 | import pydantic
9 | 
10 | 
11 | class RuntimeConfig(base_config.BaseConfig):
12 |   wandb: WandbConfig = pydantic.Field(None)
13 |   enable_tensorfloat32: bool = pydantic.Field(
14 |     False, description="Use tensorfloat32 if on Ampere devices."
15 |   )
16 |   enable_amp: bool = pydantic.Field(False, description="Enable automatic mixed precision.")
17 | 
18 | 
19 | class TrainingConfig(base_config.BaseConfig):
20 |   save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
21 |   num_train_steps: pydantic.PositiveInt = 10000
22 |   initial_checkpoint_dir: str = pydantic.Field(
23 |     None, description="Directory of initial checkpoints", at_most_one_of="initialization"
24 |   )
25 |   checkpoint_every_n: pydantic.PositiveInt = 1000
26 |   checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field(
27 |     None, description="Maximum number of checkpoints to keep. Defaults to keeping all."
28 |   )
29 |   train_log_every_n: pydantic.PositiveInt = 1000
30 |   num_eval_steps: int = pydantic.Field(
31 |     16384, description="Number of evaluation steps. If < 0 the entire dataset will be used."
32 |   )
33 |   eval_log_every_n: pydantic.PositiveInt = 5000
34 | 
35 |   eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60
36 | 
37 |   gradient_accumulation: int = pydantic.Field(
38 |     None, description="Number of replica steps to accumulate gradients."
39 |   )
40 |   num_epochs: pydantic.PositiveInt = 1
41 | 
--------------------------------------------------------------------------------
/core/debug_training_loop.py:
--------------------------------------------------------------------------------
1 | """This is a very limited-feature training loop, useful for interactive debugging.
2 | 
3 | It is not intended for actual model training (it is not fast and doesn't compile the model).
4 | It does not support checkpointing.
5 | 
6 | suggested use:
7 | 
8 | from tml.core import debug_training_loop
9 | debug_training_loop.train(...)
10 | """
11 | 
12 | from typing import Iterable, Optional, Dict, Callable, List
13 | import torch
14 | from torch.optim.lr_scheduler import _LRScheduler
15 | import torchmetrics as tm
16 | 
17 | from tml.ml_logging.torch_logging import logging
18 | 
19 | 
20 | def train(
21 |   model: torch.nn.Module,
22 |   optimizer: torch.optim.Optimizer,
23 |   train_steps: int,
24 |   dataset: Iterable,
25 |   scheduler: _LRScheduler = None,
26 |   # Accept any arguments (to be compatible with the real training loop)
27 |   # but just ignore them.
28 |   *args,
29 |   **kwargs,
30 | ) -> None:
31 | 
32 |   logging.warning("Running debug training loop, don't use for model training.")
33 | 
34 |   data_iter = iter(dataset)
35 |   for step in range(0, train_steps + 1):
36 |     x = next(data_iter)
37 |     optimizer.zero_grad()
38 |     loss, outputs = model.forward(x)
39 |     loss.backward()
40 |     optimizer.step()
41 | 
42 |     if scheduler:
43 |       scheduler.step()
44 | 
45 |     logging.info(f"Step {step} completed. Loss = {loss}")
46 | 
--------------------------------------------------------------------------------
/core/loss_type.py:
--------------------------------------------------------------------------------
1 | """Loss type enums."""
2 | from enum import Enum
3 | 
4 | 
5 | class LossType(str, Enum):
6 |   CROSS_ENTROPY = "cross_entropy"
7 |   BCE_WITH_LOGITS = "bce_with_logits"
8 | 
--------------------------------------------------------------------------------
/core/losses.py:
--------------------------------------------------------------------------------
1 | """Loss functions -- including multi-task ones."""
2 | 
3 | import typing
4 | 
5 | from tml.core.loss_type import LossType
6 | from tml.ml_logging.torch_logging import logging
7 | 
8 | import torch
9 | 
10 | 
11 | def _maybe_warn(reduction: str):
12 |   """
13 |   Warn if the reduction differs from mean.
14 |   """
15 |   if reduction != "mean":
16 |     logging.warn(
17 |       f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal "
18 |       f"to the gradient without DDP only for mean reduction. If you need this property for "
19 |       f"the provided reduction {reduction}, it needs to be implemented."
20 |     )
21 | 
22 | 
23 | def build_loss(
24 |   loss_type: LossType,
25 |   reduction="mean",
26 | ):
27 |   _maybe_warn(reduction)
28 |   f = _LOSS_TYPE_TO_FUNCTION[loss_type]
29 | 
30 |   def loss_fn(logits, labels):
31 |     return f(logits, labels.type_as(logits), reduction=reduction)
32 | 
33 |   return loss_fn
34 | 
35 | 
36 | def get_global_loss_detached(local_loss, reduction="mean"):
37 |   """
38 |   Perform all_reduce to obtain the global loss using the provided reduction.
39 |   :param local_loss: The local loss of the current rank.
40 |   :param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP.
41 |   :return: The reduced & detached global loss.
42 |   """
43 |   if reduction != "mean":
44 |     logging.warn(
45 |       f"The reduction used in this function should be the same as the one used by "
46 |       f"the DDP model. By default DDP uses mean, so ensure that DDP is appropriately "
47 |       f"modified for reduction {reduction}."
48 | ) 49 | 50 | if reduction not in ["mean", "sum"]: 51 | raise ValueError(f"Reduction {reduction} is currently unsupported.") 52 | 53 | global_loss = local_loss.detach() 54 | 55 | if reduction == "mean": 56 | global_loss.div_(torch.distributed.get_world_size()) 57 | 58 | torch.distributed.all_reduce(global_loss) 59 | return global_loss 60 | 61 | 62 | def build_multi_task_loss( 63 | loss_type: LossType, 64 | tasks: typing.List[str], 65 | task_loss_reduction="mean", 66 | global_reduction="mean", 67 | pos_weights=None, 68 | ): 69 | _maybe_warn(global_reduction) 70 | _maybe_warn(task_loss_reduction) 71 | f = _LOSS_TYPE_TO_FUNCTION[loss_type] 72 | 73 | loss_reduction_fns = { 74 | "mean": torch.mean, 75 | "sum": torch.sum, 76 | "min": torch.min, 77 | "max": torch.max, 78 | "median": torch.median, 79 | } 80 | 81 | def loss_fn(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor): 82 | if pos_weights is None: 83 | torch_weights = torch.ones([len(tasks)]) 84 | else: 85 | torch_weights = torch.tensor(pos_weights) 86 | 87 | losses = {} 88 | for task_idx, task in enumerate(tasks): 89 | task_logits = logits[:, task_idx] 90 | label = labels[:, task_idx].type_as(task_logits) 91 | 92 | loss = f( 93 | task_logits, 94 | label, 95 | reduction=task_loss_reduction, 96 | pos_weight=torch_weights[task_idx], 97 | weight=weights[:, task_idx], 98 | ) 99 | losses[f"loss/{task}"] = loss 100 | 101 | losses["loss"] = loss_reduction_fns[global_reduction](torch.stack(list(losses.values()))) 102 | return losses 103 | 104 | return loss_fn 105 | 106 | 107 | _LOSS_TYPE_TO_FUNCTION = { 108 | LossType.BCE_WITH_LOGITS: torch.nn.functional.binary_cross_entropy_with_logits 109 | } 110 | -------------------------------------------------------------------------------- /core/metric_mixin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mixin that requires a transform to munge output dictionary of tensors a 3 | model produces to a form that the torchmetrics.Metric.update expects. 4 | 5 | By unifying on our signature for `update`, we can also now use 6 | torchmetrics.MetricCollection which requires all metrics have 7 | the same call signature. 8 | 9 | To use, override this with a transform that munges `outputs` 10 | into a kwargs dict that the inherited metric.update accepts. 11 | 12 | Here are two examples of how to extend torchmetrics.SumMetric so that it accepts 13 | an output dictionary of tensors and munges it to what SumMetric expects (single `value`) 14 | for its update method. 15 | 16 | 1. Using as a mixin to inherit from or define a new metric class. 17 | 18 | class Count(MetricMixin, SumMetric): 19 | def transform(self, outputs): 20 | return {'value': 1} 21 | 22 | 2. Redefine an existing metric class. 23 | 24 | SumMetric = prepend_transform(SumMetric, lambda outputs: {'value': 1}) 25 | 26 | """ 27 | from abc import abstractmethod 28 | from typing import Callable, Dict, List 29 | 30 | from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined] 31 | 32 | import torch 33 | import torchmetrics 34 | 35 | 36 | class MetricMixin: 37 | @abstractmethod 38 | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict: 39 | ... 40 | 41 | def update(self, outputs: Dict[str, torch.Tensor]): 42 | results = self.transform(outputs) 43 | # Do not try to update if any tensor is empty as a result of stratification. 
44 | for value in results.values(): 45 | if torch.is_tensor(value) and not value.nelement(): 46 | return 47 | super().update(**results) 48 | 49 | 50 | class TaskMixin: 51 | def __init__(self, task_idx: int = -1, **kwargs): 52 | super().__init__(**kwargs) 53 | self._task_idx = task_idx 54 | 55 | 56 | class StratifyMixin: 57 | def __init__( 58 | self, 59 | stratifier=None, 60 | **kwargs, 61 | ): 62 | super().__init__(**kwargs) 63 | self._stratifier = stratifier 64 | 65 | def maybe_apply_stratification( 66 | self, outputs: Dict[str, torch.Tensor], value_names: List[str] 67 | ) -> Dict[str, torch.Tensor]: 68 | """Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.""" 69 | outputs = outputs.copy() 70 | if not self._stratifier: 71 | return outputs 72 | stratifiers = outputs.get("stratifiers") 73 | if not stratifiers: 74 | return outputs 75 | if stratifiers.get(self._stratifier.name) is None: 76 | return outputs 77 | 78 | mask = torch.flatten(outputs["stratifiers"][self._stratifier.name] == self._stratifier.value) 79 | target_slice = torch.squeeze(mask.nonzero(), -1) 80 | for value_name in value_names: 81 | target = outputs[value_name] 82 | outputs[value_name] = torch.index_select(target, 0, target_slice) 83 | return outputs 84 | 85 | 86 | def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable): 87 | """Returns new class using MetricMixin and given base_metric. 88 | 89 | Functionally the same using inheritance, just saves some lines of code 90 | if no need for class attributes. 91 | 92 | """ 93 | 94 | def transform_method(_self, *args, **kwargs): 95 | return transform(*args, **kwargs) 96 | 97 | return type( 98 | base_metric.__name__, 99 | ( 100 | MetricMixin, 101 | base_metric, 102 | ), 103 | {"transform": transform_method}, 104 | ) 105 | -------------------------------------------------------------------------------- /core/metrics.py: -------------------------------------------------------------------------------- 1 | """Common metrics that also support multi task. 2 | 3 | We assume multi task models will output [task_idx, ...] 
predictions 4 | 5 | """ 6 | from typing import Any, Dict 7 | 8 | from tml.core.metric_mixin import MetricMixin, StratifyMixin, TaskMixin 9 | 10 | import torch 11 | import torchmetrics as tm 12 | 13 | 14 | def probs_and_labels( 15 | outputs: Dict[str, torch.Tensor], 16 | task_idx: int, 17 | ) -> Dict[str, torch.Tensor]: 18 | preds = outputs["probabilities"] 19 | target = outputs["labels"] 20 | if task_idx >= 0: 21 | preds = preds[:, task_idx] 22 | target = target[:, task_idx] 23 | return { 24 | "preds": preds, 25 | "target": target.int(), 26 | } 27 | 28 | 29 | class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric): 30 | def transform(self, outputs): 31 | outputs = self.maybe_apply_stratification(outputs, ["labels"]) 32 | value = outputs["labels"] 33 | if self._task_idx >= 0: 34 | value = value[:, self._task_idx] 35 | return {"value": value} 36 | 37 | 38 | class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): 39 | def transform(self, outputs): 40 | outputs = self.maybe_apply_stratification(outputs, ["labels"]) 41 | value = outputs["labels"] 42 | if self._task_idx >= 0: 43 | value = value[:, self._task_idx] 44 | return {"value": value} 45 | 46 | 47 | class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): 48 | def transform(self, outputs): 49 | outputs = self.maybe_apply_stratification(outputs, ["probabilities"]) 50 | value = outputs["probabilities"] 51 | if self._task_idx >= 0: 52 | value = value[:, self._task_idx] 53 | return {"value": value} 54 | 55 | 56 | class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision): 57 | def transform(self, outputs): 58 | outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) 59 | return probs_and_labels(outputs, self._task_idx) 60 | 61 | 62 | class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall): 63 | def transform(self, outputs): 64 | outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) 65 | return probs_and_labels(outputs, self._task_idx) 66 | 67 | 68 | class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC): 69 | def transform(self, outputs): 70 | outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) 71 | return probs_and_labels(outputs, self._task_idx) 72 | 73 | 74 | class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): 75 | """ 76 | Based on: 77 | https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420 78 | """ 79 | 80 | def __init__(self, num_samples, **kwargs): 81 | super().__init__(**kwargs) 82 | self.num_samples = num_samples 83 | 84 | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: 85 | scores, labels = outputs["logits"], outputs["labels"] 86 | pos_scores = scores[labels == 1] 87 | neg_scores = scores[labels == 0] 88 | result = { 89 | "value": pos_scores[torch.randint(len(pos_scores), (self.num_samples,))] 90 | > neg_scores[torch.randint(len(neg_scores), (self.num_samples,))] 91 | } 92 | return result 93 | 94 | 95 | class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): 96 | """ 97 | The ranks of all positives 98 | Based on: 99 | https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73 100 | """ 101 | 102 | def __init__(self, **kwargs): 103 | super().__init__(**kwargs) 104 | 105 | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: 106 | scores, labels = outputs["logits"], outputs["labels"] 107 | 
_, sorted_indices = scores.sort(descending=True) 108 | pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1 # all ranks start from 1 109 | result = {"value": pos_ranks} 110 | return result 111 | 112 | 113 | class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): 114 | """ 115 | The reciprocal of the ranks of all 116 | Based on: 117 | https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74 118 | """ 119 | 120 | def __init__(self, **kwargs): 121 | super().__init__(**kwargs) 122 | 123 | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: 124 | scores, labels = outputs["logits"], outputs["labels"] 125 | _, sorted_indices = scores.sort(descending=True) 126 | pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1 # all ranks start from 1 127 | result = {"value": torch.div(torch.ones_like(pos_ranks), pos_ranks)} 128 | return result 129 | 130 | 131 | class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): 132 | """ 133 | The fraction of positives that rank in the top K among their negatives 134 | Note that this is basically precision@k 135 | Based on: 136 | https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75 137 | """ 138 | 139 | def __init__(self, k: int, **kwargs): 140 | super().__init__(**kwargs) 141 | self.k = k 142 | 143 | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: 144 | scores, labels = outputs["logits"], outputs["labels"] 145 | _, sorted_indices = scores.sort(descending=True) 146 | pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1 # all ranks start from 1 147 | result = {"value": (pos_ranks <= self.k).float()} 148 | return result 149 | -------------------------------------------------------------------------------- /core/test_metrics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from tml.core import metrics as core_metrics 4 | from tml.core.metric_mixin import MetricMixin, prepend_transform 5 | 6 | import torch 7 | from torchmetrics import MaxMetric, MetricCollection, SumMetric 8 | 9 | 10 | @dataclass 11 | class MockStratifierConfig: 12 | name: str 13 | index: int 14 | value: int 15 | 16 | 17 | class Count(MetricMixin, SumMetric): 18 | def transform(self, outputs): 19 | return {"value": 1} 20 | 21 | 22 | Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]}) 23 | 24 | 25 | def test_count_metric(): 26 | num_examples = 123 27 | examples = [ 28 | {"stuff": 0}, 29 | ] * num_examples 30 | 31 | metric = Count() 32 | for outputs in examples: 33 | metric.update(outputs) 34 | 35 | assert metric.compute().item() == num_examples 36 | 37 | 38 | def test_collections(): 39 | max_metric = Max() 40 | count_metric = Count() 41 | metric = MetricCollection([max_metric, count_metric]) 42 | 43 | examples = [{"value": idx} for idx in range(123)] 44 | for outputs in examples: 45 | metric.update(outputs) 46 | 47 | assert metric.compute() == { 48 | max_metric.__class__.__name__: len(examples) - 1, 49 | count_metric.__class__.__name__: len(examples), 50 | } 51 | 52 | 53 | def test_task_dependent_ctr(): 54 | num_examples = 144 55 | batch_size = 1024 56 | outputs = [ 57 | { 58 | "stuff": 0, 59 | "labels": torch.arange(0, 6).repeat(batch_size, 1), 60 | } 61 | for idx in range(num_examples) 62 | ] 63 | 64 | for task_idx in range(5): 65 | metric = 
core_metrics.Ctr(task_idx=task_idx)
66 |     for output in outputs:
67 |       metric.update(output)
68 |     assert metric.compute().item() == task_idx
69 | 
70 | 
71 | def test_stratified_ctr():
72 |   outputs = [
73 |     {
74 |       "stuff": 0,
75 |       # [bsz, tasks]
76 |       "labels": torch.tensor(
77 |         [
78 |           [0, 1, 2, 3],
79 |           [1, 2, 3, 4],
80 |           [2, 3, 4, 0],
81 |         ]
82 |       ),
83 |       "stratifiers": {
84 |         # [bsz]
85 |         "level": torch.tensor(
86 |           [9, 0, 9],
87 |         ),
88 |       },
89 |     }
90 |   ]
91 | 
92 |   stratifier = MockStratifierConfig(name="level", index=2, value=9)
93 |   for task_idx in range(5):
94 |     metric = core_metrics.Ctr(task_idx=1, stratifier=stratifier)
95 |     for output in outputs:
96 |       metric.update(output)
97 |     # From the dataset of:
98 |     # [
99 |     #   [0, 1, 2, 3],
100 |     #   [1, 2, 3, 4],
101 |     #   [2, 3, 4, 0],
102 |     # ]
103 |     # we pick out
104 |     # [
105 |     #   [0, 1, 2, 3],
106 |     #   [2, 3, 4, 0],
107 |     # ]
108 |     # and with Ctr task_idx, we pick out
109 |     # [
110 |     #   [1,],
111 |     #   [3,],
112 |     # ]
113 |     assert metric.compute().item() == (1 + 3) / 2
114 | 
115 | 
116 | def test_auc():
117 |   num_samples = 10000
118 |   metric = core_metrics.Auc(num_samples)
119 |   target = torch.tensor([0, 0, 1, 1, 1])
120 |   preds_correct = torch.tensor([-1.0, -1.0, 1.0, 1.0, 1.0])
121 |   outputs_correct = {"logits": preds_correct, "labels": target}
122 |   preds_bad = torch.tensor([1.0, 1.0, -1.0, -1.0, -1.0])
123 |   outputs_bad = {"logits": preds_bad, "labels": target}
124 | 
125 |   metric.update(outputs_correct)
126 |   assert metric.compute().item() == 1.0
127 | 
128 |   metric.reset()
129 |   metric.update(outputs_bad)
130 |   assert metric.compute().item() == 0.0
131 | 
132 | 
133 | def test_pos_rank():
134 |   metric = core_metrics.PosRanks()
135 |   target = torch.tensor([0, 0, 1, 1, 1])
136 |   preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
137 |   outputs_correct = {"logits": preds_correct, "labels": target}
138 |   preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
139 |   outputs_bad = {"logits": preds_bad, "labels": target}
140 | 
141 |   metric.update(outputs_correct)
142 |   assert metric.compute().item() == 2.0
143 | 
144 |   metric.reset()
145 |   metric.update(outputs_bad)
146 |   assert metric.compute().item() == 4.0
147 | 
148 | 
149 | def test_reciprocal_rank():
150 |   metric = core_metrics.ReciprocalRank()
151 |   target = torch.tensor([0, 0, 1, 1, 1])
152 |   preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
153 |   outputs_correct = {"logits": preds_correct, "labels": target}
154 |   preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
155 |   outputs_bad = {"logits": preds_bad, "labels": target}
156 | 
157 |   metric.update(outputs_correct)
158 |   assert abs(metric.compute().item() - 0.6111) < 0.001
159 | 
160 |   metric.reset()
161 |   metric.update(outputs_bad)
162 |   assert abs(metric.compute().item() - 0.2611) < 0.001
163 | 
164 | 
165 | def test_hit_k():
166 |   hit1_metric = core_metrics.HitAtK(1)
167 |   target = torch.tensor([0, 0, 1, 1, 1])
168 |   preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])
169 |   outputs_correct = {"logits": preds_correct, "labels": target}
170 |   preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
171 |   outputs_bad = {"logits": preds_bad, "labels": target}
172 | 
173 |   hit1_metric.update(outputs_correct)
174 |   assert abs(hit1_metric.compute().item() - 0.3333) < 0.0001
175 | 
176 |   hit1_metric.reset()
177 |   hit1_metric.update(outputs_bad)
178 | 
179 |   assert hit1_metric.compute().item() == 0
180 | 
181 |   hit3_metric = core_metrics.HitAtK(3)
182 |   hit3_metric.update(outputs_correct)
183 |   assert abs(hit3_metric.compute().item() - 0.66666) < 0.0001
164 | 
165 | def test_hit_k():
166 |   hit1_metric = core_metrics.HitAtK(1)
167 |   target = torch.tensor([0, 0, 1, 1, 1])
168 |   preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])
169 |   outputs_correct = {"logits": preds_correct, "labels": target}
170 |   preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
171 |   outputs_bad = {"logits": preds_bad, "labels": target}
172 | 
173 |   hit1_metric.update(outputs_correct)
174 |   assert abs(hit1_metric.compute().item() - 0.3333) < 0.0001
175 | 
176 |   hit1_metric.reset()
177 |   hit1_metric.update(outputs_bad)
178 | 
179 |   assert hit1_metric.compute().item() == 0
180 | 
181 |   hit3_metric = core_metrics.HitAtK(3)
182 |   hit3_metric.update(outputs_correct)
183 |   assert abs(hit3_metric.compute().item() - 0.66666) < 0.0001
184 | 
185 |   hit3_metric.reset()
186 |   hit3_metric.update(outputs_bad)
187 |   assert abs(hit3_metric.compute().item() - 0.3333) < 0.0001
-------------------------------------------------------------------------------- /core/test_train_pipeline.py: --------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Tuple
3 | 
4 | from tml.common.batch import DataclassBatch
5 | from tml.common.testing_utils import mock_pg
6 | from tml.core import train_pipeline
7 | 
8 | import torch
9 | from torchrec.distributed import DistributedModelParallel
10 | 
11 | 
12 | @dataclass
13 | class MockDataclassBatch(DataclassBatch):
14 |   continuous_features: torch.Tensor
15 |   labels: torch.Tensor
16 | 
17 | 
18 | class MockModule(torch.nn.Module):
19 |   def __init__(self) -> None:
20 |     super().__init__()
21 |     self.model = torch.nn.Linear(10, 1)
22 |     self.loss_fn = torch.nn.BCEWithLogitsLoss()
23 | 
24 |   def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
25 |     pred = self.model(batch.continuous_features)
26 |     loss = self.loss_fn(pred, batch.labels)
27 |     return (loss, pred)
28 | 
29 | 
30 | def create_batch(bsz: int):
31 |   return MockDataclassBatch(
32 |     continuous_features=torch.rand(bsz, 10).float(),
33 |     labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
34 |   )
35 | 
36 | 
37 | def test_sparse_pipeline():
38 |   device = torch.device("cpu")
39 |   model = MockModule().to(device)
40 | 
41 |   steps = 8
42 |   example = create_batch(1)
43 |   dataloader = iter(example for _ in range(steps + 2))
44 | 
45 |   results = []
46 |   with mock_pg():
47 |     d_model = DistributedModelParallel(model)
48 |     pipeline = train_pipeline.TrainPipelineSparseDist(
49 |       model=d_model,
50 |       optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
51 |       device=device,
52 |       grad_accum=2,
53 |     )
54 |     for _ in range(steps):
55 |       results.append(pipeline.progress(dataloader))
56 | 
57 |   results = [elem.detach().numpy() for elem in results]
58 |   # Check gradients are accumulated, i.e. results do not change within each pair of steps (0th and 1st, 2nd and 3rd, ...).
59 |   for first, second in zip(results[::2], results[1::2]):
60 |     assert first == second, results
61 | 
62 |   # Check we do apply gradients, i.e. results do change across pairs (1st and 2nd, 3rd and 4th, ...).
63 |   for first, second in zip(results[1::2], results[2::2]):
64 |     assert first != second, results
65 | 
66 | 
67 | def test_amp():
68 |   device = torch.device("cpu")
69 |   model = MockModule().to(device)
70 | 
71 |   steps = 8
72 |   example = create_batch(1)
73 |   dataloader = iter(example for _ in range(steps + 2))
74 | 
75 |   results = []
76 |   with mock_pg():
77 |     d_model = DistributedModelParallel(model)
78 |     pipeline = train_pipeline.TrainPipelineSparseDist(
79 |       model=d_model,
80 |       optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
81 |       device=device,
82 |       enable_amp=True,
83 |       # Not supported on CPU.
84 |       enable_grad_scaling=False,
85 |     )
86 |     for _ in range(steps):
87 |       results.append(pipeline.progress(dataloader))
88 | 
89 |   results = [elem.detach() for elem in results]
90 |   for value in results:
91 |     assert value.dtype == torch.bfloat16
-------------------------------------------------------------------------------- /images/init_venv.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [[ "$(uname)" == "Darwin" ]]; then
4 |   echo "Only supported on Linux."
5 |   exit 1
6 | fi
7 | 
8 | # You may need to point this to a version of python 3.10
9 | PYTHONBIN="/opt/ee/python/3.10/bin/python3.10"
10 | echo Using "PYTHONBIN=$PYTHONBIN"
11 | 
12 | # Put venv in tmp, these things are not made to last, just rebuild.
13 | VENV_PATH="$HOME/tml_venv"
14 | rm -rf "$VENV_PATH"
15 | "$PYTHONBIN" -m venv "$VENV_PATH"
16 | 
17 | # shellcheck source=/dev/null
18 | . "$VENV_PATH/bin/activate"
19 | 
20 | pip --require-virtualenv install -U pip
21 | pip --require-virtualenv install --no-deps -r images/requirements.txt
22 | 
23 | ln -s "$(pwd)" "$VENV_PATH/lib/python3.10/site-packages/tml"
24 | 
25 | echo "Now run 'source ${VENV_PATH}/bin/activate' to get going."
-------------------------------------------------------------------------------- /images/requirements.txt: --------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | aiofiles==22.1.0
3 | aiohttp==3.8.3
4 | aiosignal==1.3.1
5 | appdirs==1.4.4
6 | arrow==1.2.3
7 | asttokens==2.2.1
8 | astunparse==1.6.3
9 | async-timeout==4.0.2
10 | attrs==22.1.0
11 | backcall==0.2.0
12 | black==22.6.0
13 | cachetools==5.3.0
14 | cblack==22.6.0
15 | certifi==2022.12.7
16 | cfgv==3.3.1
17 | charset-normalizer==2.1.1
18 | click==8.1.3
19 | cmake==3.25.0
20 | Cython==0.29.32
21 | decorator==5.1.1
22 | distlib==0.3.6
23 | distro==1.8.0
24 | dm-tree==0.1.6
25 | docker==6.0.1
26 | docker-pycreds==0.4.0
27 | docstring-parser==0.8.1
28 | exceptiongroup==1.1.0
29 | executing==1.2.0
30 | fbgemm-gpu-cpu==0.3.2
31 | filelock==3.8.2
32 | fire==0.5.0
33 | flatbuffers==1.12
34 | frozenlist==1.3.3
35 | fsspec==2022.11.0
36 | gast==0.4.0
37 | gcsfs==2022.11.0
38 | gitdb==4.0.10
39 | GitPython==3.1.31
40 | google-api-core==2.8.2
41 | google-auth==2.16.0
42 | google-auth-oauthlib==0.4.6
43 | google-cloud-core==2.3.2
44 | google-cloud-storage==2.7.0
45 | google-crc32c==1.5.0
46 | google-pasta==0.2.0
47 | google-resumable-media==2.4.1
48 | googleapis-common-protos==1.56.4
49 | grpcio==1.51.1
50 | h5py==3.8.0
51 | hypothesis==6.61.0
52 | identify==2.5.17
53 | idna==3.4
54 | importlib-metadata==6.0.0
55 | iniconfig==2.0.0
56 | iopath==0.1.10
57 | ipdb==0.13.11
58 | ipython==8.10.0
59 | jedi==0.18.2
60 | Jinja2==3.1.2
61 | keras==2.9.0
62 | Keras-Preprocessing==1.1.2
63 | libclang==15.0.6.1
64 | libcst==0.4.9
65 | Markdown==3.4.1
66 | MarkupSafe==2.1.1
67 | matplotlib-inline==0.1.6
68 | moreorless==0.4.0
69 | multidict==6.0.4
70 | mypy==1.0.1
71 | mypy-extensions==0.4.3
72 | nest-asyncio==1.5.6
73 | ninja==1.11.1
74 | nodeenv==1.7.0
75 | numpy==1.22.0
76 | nvidia-cublas-cu11==11.10.3.66
77 | nvidia-cuda-nvrtc-cu11==11.7.99
78 | nvidia-cuda-runtime-cu11==11.7.99
79 | nvidia-cudnn-cu11==8.5.0.96
80 | oauthlib==3.2.2
81 | opt-einsum==3.3.0
82 | packaging==22.0
83 | pandas==1.5.3
84 | parso==0.8.3
85 | pathspec==0.11.0
86 | pathtools==0.1.2
87 | pexpect==4.8.0
88 | pickleshare==0.7.5
89 | platformdirs==3.0.0
90 | pluggy==1.0.0
91 | portalocker==2.6.0
92 | portpicker==1.5.2
93 | pre-commit==3.0.4
94 | prompt-toolkit==3.0.36
95 | protobuf==3.20.2
96 | psutil==5.9.4
97 | ptyprocess==0.7.0
98 | pure-eval==0.2.2
99 | pyarrow==10.0.1
100 | pyasn1==0.4.8
101 | pyasn1-modules==0.2.8
102 | pydantic==1.9.0
103 | pyDeprecate==0.3.2
104 | Pygments==2.14.0
105 | pyparsing==3.0.9
106 | pyre-extensions==0.0.27
107 | pytest==7.2.1
108 | pytest-mypy==0.10.3
109 | python-dateutil==2.8.2
110 | pytz==2022.6
111 | PyYAML==6.0.0
112 | requests==2.28.1
113 | requests-oauthlib==1.3.1
114 | rsa==4.9
115 | scikit-build==0.16.3
116 | 
sentry-sdk==1.16.0 117 | setproctitle==1.3.2 118 | six==1.16.0 119 | smmap==5.0.0 120 | sortedcontainers==2.4.0 121 | stack-data==0.6.2 122 | stdlibs==2022.10.9 123 | tabulate==0.9.0 124 | tensorboard==2.9.0 125 | tensorboard-data-server==0.6.1 126 | tensorboard-plugin-wit==1.8.1 127 | tensorflow==2.9.3 128 | tensorflow-estimator==2.9.0 129 | tensorflow-io-gcs-filesystem==0.30.0 130 | termcolor==2.2.0 131 | toml==0.10.2 132 | tomli==2.0.1 133 | torch==1.13.1 134 | torchmetrics==0.11.0 135 | torchrec==0.3.2 136 | torchsnapshot==0.1.0 137 | torchx==0.3.0 138 | tqdm==4.64.1 139 | trailrunner==1.2.1 140 | traitlets==5.9.0 141 | typing-inspect==0.8.0 142 | typing_extensions==4.4.0 143 | urllib3==1.26.13 144 | usort==1.0.5 145 | virtualenv==20.19.0 146 | wandb==0.13.11 147 | wcwidth==0.2.6 148 | websocket-client==1.4.2 149 | Werkzeug==2.2.3 150 | wrapt==1.14.1 151 | yarl==1.8.2 152 | zipp==3.12.1 153 | -------------------------------------------------------------------------------- /machines/environment.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import List 4 | 5 | 6 | KF_DDS_PORT: int = 5050 7 | SLURM_DDS_PORT: int = 5051 8 | FLIGHT_SERVER_PORT: int = 2222 9 | 10 | 11 | def on_kf(): 12 | return "SPEC_TYPE" in os.environ 13 | 14 | 15 | def has_readers(): 16 | if on_kf(): 17 | machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) 18 | return machines_config_env["dataset_worker"] is not None 19 | return os.environ.get("HAS_READERS", "False") == "True" 20 | 21 | 22 | def get_task_type(): 23 | if on_kf(): 24 | return os.environ["SPEC_TYPE"] 25 | return os.environ["TASK_TYPE"] 26 | 27 | 28 | def is_chief() -> bool: 29 | return get_task_type() == "chief" 30 | 31 | 32 | def is_reader() -> bool: 33 | return get_task_type() == "datasetworker" 34 | 35 | 36 | def is_dispatcher() -> bool: 37 | return get_task_type() == "datasetdispatcher" 38 | 39 | 40 | def get_task_index(): 41 | if on_kf(): 42 | pod_name = os.environ["MY_POD_NAME"] 43 | return int(pod_name.split("-")[-1]) 44 | else: 45 | raise NotImplementedError 46 | 47 | 48 | def get_reader_port(): 49 | if on_kf(): 50 | return KF_DDS_PORT 51 | return SLURM_DDS_PORT 52 | 53 | 54 | def get_dds(): 55 | if not has_readers(): 56 | return None 57 | dispatcher_address = get_dds_dispatcher_address() 58 | if dispatcher_address: 59 | return f"grpc://{dispatcher_address}" 60 | else: 61 | raise ValueError("Job does not have DDS.") 62 | 63 | 64 | def get_dds_dispatcher_address(): 65 | if not has_readers(): 66 | return None 67 | if on_kf(): 68 | job_name = os.environ["JOB_NAME"] 69 | dds_host = f"{job_name}-datasetdispatcher-0" 70 | else: 71 | dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"] 72 | return f"{dds_host}:{get_reader_port()}" 73 | 74 | 75 | def get_dds_worker_address(): 76 | if not has_readers(): 77 | return None 78 | if on_kf(): 79 | job_name = os.environ["JOB_NAME"] 80 | task_index = get_task_index() 81 | return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}" 82 | else: 83 | node = os.environ["SLURMD_NODENAME"] 84 | return f"{node}:{get_reader_port()}" 85 | 86 | 87 | def get_num_readers(): 88 | if not has_readers(): 89 | return 0 90 | if on_kf(): 91 | machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) 92 | return int(machines_config_env["num_dataset_workers"] or 0) 93 | return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(",")) 94 | 95 | 96 | def get_flight_server_addresses(): 97 | if on_kf(): 98 | job_name = 
os.environ["JOB_NAME"]
99 |     return [
100 |       f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
101 |       for task_index in range(get_num_readers())
102 |     ]
103 |   else:
104 |     raise NotImplementedError
105 | 
106 | 
107 | def get_dds_journaling_dir():
108 |   return os.environ.get("DATASET_JOURNALING_DIR", None)
-------------------------------------------------------------------------------- /machines/get_env.py: --------------------------------------------------------------------------------
1 | import tml.machines.environment as env
2 | 
3 | from absl import app, flags
4 | 
5 | 
6 | FLAGS = flags.FLAGS
7 | flags.DEFINE_string("property", None, "Which property of the current environment to fetch.")
8 | 
9 | 
10 | def main(argv):
11 |   if FLAGS.property == "using_dds":
12 |     print(f"{env.has_readers()}", flush=True)
13 |   elif FLAGS.property == "has_readers":
14 |     print(f"{env.has_readers()}", flush=True)
15 |   elif FLAGS.property == "get_task_type":
16 |     print(f"{env.get_task_type()}", flush=True)
17 |   elif FLAGS.property == "is_datasetworker":
18 |     print(f"{env.is_reader()}", flush=True)
19 |   elif FLAGS.property == "is_dds_dispatcher":
20 |     print(f"{env.is_dispatcher()}", flush=True)
21 |   elif FLAGS.property == "get_task_index":
22 |     print(f"{env.get_task_index()}", flush=True)
23 |   elif FLAGS.property == "get_dataset_service":
24 |     print(f"{env.get_dds()}", flush=True)
25 |   elif FLAGS.property == "get_dds_dispatcher_address":
26 |     print(f"{env.get_dds_dispatcher_address()}", flush=True)
27 |   elif FLAGS.property == "get_dds_worker_address":
28 |     print(f"{env.get_dds_worker_address()}", flush=True)
29 |   elif FLAGS.property == "get_dds_port":
30 |     print(f"{env.get_reader_port()}", flush=True)
31 |   elif FLAGS.property == "get_dds_journaling_dir":
32 |     print(f"{env.get_dds_journaling_dir()}", flush=True)
33 |   elif FLAGS.property == "should_start_dds":
34 |     print(env.is_reader() or env.is_dispatcher(), flush=True)
35 | 
36 | 
37 | if __name__ == "__main__":
38 |   app.run(main)
-------------------------------------------------------------------------------- /machines/is_venv.py: --------------------------------------------------------------------------------
1 | """This is intended to be run as a module.
2 | e.g. python -m tml.machines.is_venv
3 | 
4 | Exits with 0 if running in venv, otherwise 1.
5 | """ 6 | 7 | import sys 8 | import logging 9 | 10 | 11 | def is_venv(): 12 | # See https://stackoverflow.com/questions/1871549/determine-if-python-is-running-inside-virtualenv 13 | return sys.base_prefix != sys.prefix 14 | 15 | 16 | def _main(): 17 | if is_venv(): 18 | logging.info("In venv %s", sys.prefix) 19 | sys.exit(0) 20 | else: 21 | logging.error("Not in venv") 22 | sys.exit(1) 23 | 24 | 25 | if __name__ == "__main__": 26 | _main() 27 | -------------------------------------------------------------------------------- /machines/list_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple str.split() parsing of input string 3 | 4 | usage example: 5 | python list_ops.py --input_list=$INPUT [--sep=","] [--op=] [--elem=$INDEX] 6 | 7 | Args: 8 | - input_list: input string 9 | - sep (default ","): separator string 10 | - elem (default 0): integer index 11 | - op (default "select"): either `len` or `select` 12 | - len: prints len(input_list.split(sep)) 13 | - select: prints input_list.split(sep)[elem] 14 | 15 | Typical usage would be in a bash script, e.g.: 16 | 17 | LIST_LEN=$(python list_ops.py --input_list=$INPUT --op=len) 18 | 19 | """ 20 | import tml.machines.environment as env 21 | 22 | from absl import app, flags 23 | 24 | 25 | FLAGS = flags.FLAGS 26 | flags.DEFINE_string("input_list", None, "string to parse as list") 27 | flags.DEFINE_integer("elem", 0, "which element to take") 28 | flags.DEFINE_string("sep", ",", "separator") 29 | flags.DEFINE_string("op", "select", "operation to do") 30 | 31 | 32 | def main(argv): 33 | split_list = FLAGS.input_list.split(FLAGS.sep) 34 | if FLAGS.op == "select": 35 | print(split_list[FLAGS.elem], flush=True) 36 | elif FLAGS.op == "len": 37 | print(len(split_list), flush=True) 38 | else: 39 | raise ValueError(f"operation {FLAGS.op} not recognized.") 40 | 41 | 42 | if __name__ == "__main__": 43 | app.run(main) 44 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import StableMean # noqa 2 | from .auroc import AUROCWithMWU # noqa 3 | from .rce import NRCE, RCE # noqa 4 | -------------------------------------------------------------------------------- /metrics/aggregation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains aggregation metrics. 3 | """ 4 | from typing import Tuple, Union 5 | 6 | import torch 7 | import torchmetrics 8 | 9 | 10 | def update_mean( 11 | current_mean: torch.Tensor, 12 | current_weight_sum: torch.Tensor, 13 | value: torch.Tensor, 14 | weight: torch.Tensor, 15 | ) -> Tuple[torch.Tensor, torch.Tensor]: 16 | """ 17 | Update the mean according to Welford formula: 18 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. 19 | See also https://nullbuffer.com/articles/welford_algorithm.html for more information. 20 | Args: 21 | current_mean: The value of the current accumulated mean. 22 | current_weight_sum: The current weighted sum. 23 | value: The new value that needs to be added to get a new mean. 24 | weight: The weights for the new value. 25 | 26 | Returns: The updated mean and updated weighted sum. 
36 | 
37 | 
38 | def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
39 |   """
40 |   Merge the state from multiple workers.
41 |   Args:
42 |     state: A tensor with the first dimension indicating workers.
43 | 
44 |   Returns: The accumulated mean from all workers.
45 | 
46 |   """
47 |   mean, weight_sum = update_mean(
48 |     current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
49 |     current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
50 |     value=state[:, 0],
51 |     weight=state[:, 1],
52 |   )
53 |   return torch.stack([mean, weight_sum])
54 | 
55 | 
56 | class StableMean(torchmetrics.Metric):
57 |   """
58 |   This implements a numerically stable mean metric using the Welford algorithm, per
59 |   https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
60 |   For example, when using float32 the algorithm gives a valid output even if the running "sum"
61 |   is larger than the maximum float32, as long as the mean itself is within the limits of float32.
62 |   See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
63 |   """
64 | 
65 |   def __init__(self, **kwargs):
66 |     """
67 |     Args:
68 |       **kwargs: Additional parameters supported by all torchmetrics.Metric.
69 |     """
70 |     super().__init__(**kwargs)
71 |     self.add_state(
72 |       "mean_and_weight_sum",
73 |       default=torch.zeros(2),
74 |       dist_reduce_fx=stable_mean_dist_reduce_fn,
75 |     )
76 | 
77 |   def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
78 |     """
79 |     Update the current mean.
80 |     Args:
81 |       value: Value to update the mean with.
82 |       weight: Weight to use. Shape should be broadcastable to that of value.
83 |     """
84 |     mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]
85 | 
86 |     if not isinstance(weight, torch.Tensor):
87 |       weight = torch.as_tensor(weight, dtype=value.dtype, device=value.device)
88 | 
89 |     self.mean_and_weight_sum[0], self.mean_and_weight_sum[1] = update_mean(
90 |       mean, weight_sum, value, torch.as_tensor(weight)
91 |     )
92 | 
93 |   def compute(self) -> torch.Tensor:
94 |     """
95 |     Compute and return the accumulated mean.
96 |     """
97 |     return self.mean_and_weight_sum[0]
-------------------------------------------------------------------------------- /metrics/auroc.py: --------------------------------------------------------------------------------
1 | """
2 | AUROC metrics.
3 | """
4 | from typing import Union
5 | 
6 | from tml.ml_logging.torch_logging import logging
7 | 
8 | import torch
9 | import torchmetrics
10 | from torchmetrics.utilities.data import dim_zero_cat
11 | 
12 | 
13 | def _compute_helper(
14 |   predictions: torch.Tensor,
15 |   target: torch.Tensor,
16 |   weights: torch.Tensor,
17 |   max_positive_negative_weighted_sum: torch.Tensor,
18 |   min_positive_negative_weighted_sum: torch.Tensor,
19 |   equal_predictions_as_incorrect: bool,
20 | ) -> torch.Tensor:
21 |   """
22 |   Compute AUROC.
23 |   Args:
24 |     predictions: The predicted probabilities.
25 |     target: The target.
26 |     weights: The sample weights to assign to each sample in the batch.
27 |     max_positive_negative_weighted_sum: The larger of the weighted sums of the positive and negative labels.
28 |     min_positive_negative_weighted_sum: The smaller of the weighted sums of the positive and negative labels.
29 |     equal_predictions_as_incorrect: For positive & negative labels having identical scores,
30 |       we assume that they are correct predictions (i.e. weight = 1) when this is False. Otherwise,
31 |       we assume that they are incorrect predictions (i.e. weight = 0).
32 |   """
33 |   dim = 0
34 | 
35 |   # Sort predictions based on key (score, true_label). The order is ascending for score.
36 |   # For true_label, order is ascending if equal_predictions_as_incorrect is True;
37 |   # otherwise it is descending.
38 |   target_order = torch.argsort(target, dim=dim, descending=equal_predictions_as_incorrect)
39 |   score_order = torch.sort(torch.gather(predictions, dim, target_order), stable=True, dim=dim)[1]
40 |   score_order = torch.gather(target_order, dim, score_order)
41 |   sorted_target = torch.gather(target, dim, score_order)
42 |   sorted_weights = torch.gather(weights, dim, score_order)
43 | 
44 |   negatives_from_left = torch.cumsum((1.0 - sorted_target) * sorted_weights, 0)
45 | 
46 |   numerator = torch.sum(
47 |     sorted_weights * (sorted_target * negatives_from_left / max_positive_negative_weighted_sum)
48 |   )
49 | 
50 |   return numerator / min_positive_negative_weighted_sum
51 | 
52 | 
53 | class AUROCWithMWU(torchmetrics.Metric):
54 |   """
55 |   AUROC using the Mann-Whitney U-test.
56 |   See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
57 | 
58 |   This AUROC implementation is well suited to (non-zero) low-CTR settings. In particular it will return
59 |   the correct AUROC even if the predicted probabilities are all close to 0.
60 |   Currently only supports binary classification.
61 |   """
62 | 
63 |   def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
64 |     """
65 | 
66 |     Args:
67 |       label_threshold: Labels strictly above this threshold are considered positive labels,
68 |         otherwise, they are considered negative.
69 |       raise_missing_class: If True, an error will be raised if the negative or positive class is missing.
70 |         Otherwise, we will simply log a warning.
71 |       **kwargs: Additional parameters supported by all torchmetrics.Metric.
72 |     """
73 |     super().__init__(**kwargs)
74 |     self.add_state("predictions", default=[], dist_reduce_fx="cat")
75 |     self.add_state("target", default=[], dist_reduce_fx="cat")
76 |     self.add_state("weights", default=[], dist_reduce_fx="cat")
77 | 
78 |     self.label_threshold = label_threshold
79 |     self.raise_missing_class = raise_missing_class
80 | 
81 |   def update(
82 |     self,
83 |     predictions: torch.Tensor,
84 |     target: torch.Tensor,
85 |     weight: Union[float, torch.Tensor] = 1.0,
86 |   ) -> None:
87 |     """
88 |     Update the current auroc.
89 |     Args:
90 |       predictions: Predicted values, 1D Tensor or 2D Tensor of shape batch_size x 1.
91 |       target: Ground truth. Must have same shape as predictions.
92 |       weight: The weight to use for the predicted values. Shape should be
93 |         broadcastable to that of predictions.
94 |     """
95 |     self.predictions.append(predictions)
96 |     self.target.append(target)
97 |     if not isinstance(weight, torch.Tensor):
98 |       weight = torch.as_tensor(weight, dtype=predictions.dtype, device=target.device)
99 |     self.weights.append(torch.broadcast_to(weight, predictions.size()))
100 | 
101 |   def compute(self) -> torch.Tensor:
102 |     """
103 |     Compute and return the accumulated AUROC.
104 | """ 105 | weights = dim_zero_cat(self.weights) 106 | predictions = dim_zero_cat(self.predictions) 107 | target = dim_zero_cat(self.target).type_as(predictions) 108 | 109 | negative_mask = target <= self.label_threshold 110 | positive_mask = torch.logical_not(negative_mask) 111 | 112 | if not negative_mask.any(): 113 | msg = "Negative class missing. AUROC returned will be meaningless." 114 | if self.raise_missing_class: 115 | raise ValueError(msg) 116 | else: 117 | logging.warn(msg) 118 | if not positive_mask.any(): 119 | msg = "Positive class missing. AUROC returned will be meaningless." 120 | if self.raise_missing_class: 121 | raise ValueError(msg) 122 | else: 123 | logging.warn(msg) 124 | 125 | weighted_actual_negative_sum = torch.sum( 126 | torch.where(negative_mask, weights, torch.zeros_like(weights)) 127 | ) 128 | 129 | weighted_actual_positive_sum = torch.sum( 130 | torch.where(positive_mask, weights, torch.zeros_like(weights)) 131 | ) 132 | 133 | max_positive_negative_weighted_sum = torch.max( 134 | weighted_actual_negative_sum, weighted_actual_positive_sum 135 | ) 136 | 137 | min_positive_negative_weighted_sum = torch.min( 138 | weighted_actual_negative_sum, weighted_actual_positive_sum 139 | ) 140 | 141 | # Compute auroc with the weight set to 1 when positive & negative have identical scores. 142 | auroc_le = _compute_helper( 143 | target=target, 144 | weights=weights, 145 | predictions=predictions, 146 | min_positive_negative_weighted_sum=min_positive_negative_weighted_sum, 147 | max_positive_negative_weighted_sum=max_positive_negative_weighted_sum, 148 | equal_predictions_as_incorrect=False, 149 | ) 150 | 151 | # Compute auroc with the weight set to 0 when positive & negative have identical scores. 152 | auroc_lt = _compute_helper( 153 | target=target, 154 | weights=weights, 155 | predictions=predictions, 156 | min_positive_negative_weighted_sum=min_positive_negative_weighted_sum, 157 | max_positive_negative_weighted_sum=max_positive_negative_weighted_sum, 158 | equal_predictions_as_incorrect=True, 159 | ) 160 | 161 | # Compute auroc with the weight set to 1/2 when positive & negative have identical scores. 162 | return auroc_le - (auroc_le - auroc_lt) / 2.0 163 | -------------------------------------------------------------------------------- /ml_logging/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twitter/the-algorithm-ml/b85210863f7a94efded0ef5c5ccf4ff42767876c/ml_logging/__init__.py -------------------------------------------------------------------------------- /ml_logging/absl_logging.py: -------------------------------------------------------------------------------- 1 | """Sets up logging through absl for training usage. 2 | 3 | - Redirects logging to sys.stdout so that severity levels in GCP Stackdriver are accurate. 
4 | 5 | Usage: 6 | >>> from twitter.ml.logging.absl_logging import logging 7 | >>> logging.info(f"Properly logged as INFO level in GCP Stackdriver.") 8 | 9 | """ 10 | import logging as py_logging 11 | import sys 12 | 13 | from absl import logging as logging 14 | 15 | 16 | def setup_absl_logging(): 17 | """Make sure that absl logging pushes to stdout rather than stderr.""" 18 | logging.get_absl_handler().python_handler.stream = sys.stdout 19 | formatter = py_logging.Formatter( 20 | fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s" 21 | ) 22 | logging.get_absl_handler().setFormatter(formatter) 23 | logging.set_verbosity(logging.INFO) 24 | 25 | 26 | setup_absl_logging() 27 | -------------------------------------------------------------------------------- /ml_logging/test_torch_logging.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tml.ml_logging.torch_logging import logging 4 | 5 | 6 | class Testtlogging(unittest.TestCase): 7 | def test_warn_once(self): 8 | with self.assertLogs(level="INFO") as captured_logs: 9 | logging.info("first info") 10 | logging.warning("first warning") 11 | logging.warning("first warning") 12 | logging.info("second info") 13 | 14 | self.assertEqual( 15 | captured_logs.output, 16 | [ 17 | "INFO:absl:first info", 18 | "WARNING:absl:first warning", 19 | "INFO:absl:second info", 20 | ], 21 | ) 22 | -------------------------------------------------------------------------------- /ml_logging/torch_logging.py: -------------------------------------------------------------------------------- 1 | """Overrides absl logger to be rank-aware for distributed pytorch usage. 2 | 3 | >>> # in-bazel import 4 | >>> from twitter.ml.logging.torch_logging import logging 5 | >>> # out-bazel import 6 | >>> from ml.logging.torch_logging import logging 7 | >>> logging.info(f"This only prints on rank 0 if distributed, otherwise prints normally.") 8 | >>> logging.info(f"This prints on all ranks if distributed, otherwise prints normally.", rank=-1) 9 | 10 | """ 11 | import functools 12 | from typing import Optional 13 | 14 | from tml.ml_logging.absl_logging import logging as logging 15 | from absl import logging as absl_logging 16 | 17 | import torch.distributed as dist 18 | 19 | 20 | def rank_specific(logger): 21 | """Ensures that we only override a given logger once.""" 22 | if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"): 23 | return logger 24 | 25 | def _if_rank(logger_method, limit: Optional[int] = None): 26 | if limit: 27 | # If we are limiting redundant logs, wrap logging call with a cache 28 | # to not execute if already cached. 29 | def _wrap(_call): 30 | @functools.lru_cache(limit) 31 | def _logger_method(*args, **kwargs): 32 | _call(*args, **kwargs) 33 | 34 | return _logger_method 35 | 36 | logger_method = _wrap(logger_method) 37 | 38 | def _inner(msg, *args, rank: int = 0, **kwargs): 39 | if not dist.is_initialized(): 40 | logger_method(msg, *args, **kwargs) 41 | elif dist.get_rank() == rank: 42 | logger_method(msg, *args, **kwargs) 43 | elif rank < 0: 44 | logger_method(f"Rank{dist.get_rank()}: {msg}", *args, **kwargs) 45 | 46 | # Register this stack frame with absl logging so that it doesn't trample logging lines. 
47 | absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__) 48 | 49 | return _inner 50 | 51 | logger.fatal = _if_rank(logger.fatal) 52 | logger.error = _if_rank(logger.error) 53 | logger.warning = _if_rank(logger.warning, limit=1) 54 | logger.info = _if_rank(logger.info) 55 | logger.debug = _if_rank(logger.debug) 56 | logger.exception = _if_rank(logger.exception) 57 | 58 | logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True 59 | 60 | 61 | rank_specific(logging) 62 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """Wraps servable model in loss and RecapBatch passing to be trainable.""" 2 | # flake8: noqa 3 | from typing import Callable 4 | 5 | from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined] 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from torchrec.distributed.model_parallel import DistributedModelParallel 10 | 11 | 12 | class ModelAndLoss(torch.nn.Module): 13 | # Reconsider our approach at a later date: https://ppwwyyxx.com/blog/2022/Loss-Function-Separation/ 14 | 15 | def __init__( 16 | self, 17 | model, 18 | loss_fn: Callable, 19 | ) -> None: 20 | """ 21 | Args: 22 | model: torch module to wrap. 23 | loss_fn: Function for calculating loss, should accept logits and labels. 24 | """ 25 | super().__init__() 26 | self.model = model 27 | self.loss_fn = loss_fn 28 | 29 | def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] 30 | """Runs model forward and calculates loss according to given loss_fn. 31 | 32 | NOTE: The input signature here needs to be a Pipelineable object for 33 | prefetching purposes during training using torchrec's pipeline. However 34 | the underlying model signature needs to be exportable to onnx, requiring 35 | generic python types. see https://pytorch.org/docs/stable/onnx.html#types. 36 | 37 | """ 38 | outputs = self.model(batch) 39 | losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float()) 40 | 41 | outputs.update( 42 | { 43 | "loss": losses, 44 | "labels": batch.labels, 45 | "weights": batch.weights, 46 | } 47 | ) 48 | 49 | # Allow multiple losses. 50 | return losses, outputs 51 | 52 | 53 | def maybe_shard_model( 54 | model, 55 | device: torch.device, 56 | ): 57 | """Set up and apply DistributedModelParallel to a model if running in a distributed environment. 58 | 59 | If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies 60 | DistributedModelParallel. 61 | 62 | If not in a distributed environment, returns model directly. 63 | """ 64 | if dist.is_initialized(): 65 | logging.info("***** Wrapping in DistributedModelParallel *****") 66 | logging.info(f"Model before wrapping: {model}") 67 | model = DistributedModelParallel( 68 | module=model, 69 | device=device, 70 | ) 71 | logging.info(f"Model after wrapping: {model}") 72 | 73 | return model 74 | 75 | 76 | def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None: 77 | """Handy function to log the content of EBC embedding layer. 78 | Only works for single GPU machines. 
79 | 
80 |   Args:
81 |     weight_name: name of tensor, as defined in model
82 |     table_name: name of the EBC table the weight is taken from
83 |     weight_tensor: embedding weight tensor
84 |   """
85 |   logging.info(f"{weight_name}, {table_name}", rank=-1)
86 |   logging.info(f"{weight_tensor.metadata()}", rank=-1)
87 |   output_tensor = torch.zeros(*weight_tensor.size(), device=torch.device("cuda:0"))
88 |   weight_tensor.gather(out=output_tensor)
89 |   logging.info(f"{output_tensor}", rank=-1)
-------------------------------------------------------------------------------- /optimizers/__init__.py: --------------------------------------------------------------------------------
1 | from tml.optimizers.optimizer import compute_lr
-------------------------------------------------------------------------------- /optimizers/config.py: --------------------------------------------------------------------------------
1 | """Optimization configurations for models."""
2 | 
3 | import typing
4 | 
5 | import tml.core.config as base_config
6 | 
7 | import pydantic
8 | 
9 | 
10 | class PiecewiseConstant(base_config.BaseConfig):
11 |   learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
12 |   learning_rate_values: typing.List[float] = pydantic.Field(None)
13 | 
14 | 
15 | class LinearRampToConstant(base_config.BaseConfig):
16 |   learning_rate: float
17 |   num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
18 |     description="Number of steps to ramp this up from zero."
19 |   )
20 | 
21 | 
22 | class LinearRampToCosine(base_config.BaseConfig):
23 |   learning_rate: float
24 |   final_learning_rate: float
25 |   num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
26 |     description="Number of steps to ramp this up from zero."
27 |   )
28 |   final_num_steps: pydantic.PositiveInt = pydantic.Field(
29 |     description="Final number of steps where decay stops."
30 |   )
31 | 
32 | 
33 | class LearningRate(base_config.BaseConfig):
34 |   constant: float = pydantic.Field(None, one_of="lr")
35 |   linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
36 |   linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
37 |   piecewise_constant: PiecewiseConstant = pydantic.Field(None, one_of="lr")
38 | 
39 | 
40 | class OptimizerAlgorithmConfig(base_config.BaseConfig):
41 |   """Base class for optimizer configurations."""
42 | 
43 |   lr: float
44 | 
45 | 
46 | class AdamConfig(OptimizerAlgorithmConfig):
47 |   # See https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
48 |   lr: float
49 |   betas: typing.Tuple[float, float] = (0.9, 0.999)
50 |   eps: float = 1e-7  # Numerical stability in denominator.
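# Illustrative usage (not part of this file): these pydantic configs are plain
# data classes, and the one_of="optimizer" fields on OptimizerConfig below mean
# exactly one algorithm should be set, e.g.
#
#   config = OptimizerConfig(
#     learning_rate=LearningRate(constant=1e-3),
#     adam=AdamConfig(lr=1e-3),
#   )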
51 | 
52 | 
53 | class SgdConfig(OptimizerAlgorithmConfig):
54 |   lr: float
55 |   momentum: float = 0.0
56 | 
57 | 
58 | class AdagradConfig(OptimizerAlgorithmConfig):
59 |   lr: float
60 |   eps: float = 0
61 | 
62 | 
63 | class OptimizerConfig(base_config.BaseConfig):
64 |   learning_rate: LearningRate = pydantic.Field(
65 |     None,
66 |     description="Constant learning rates",
67 |   )
68 |   adam: AdamConfig = pydantic.Field(None, one_of="optimizer")
69 |   sgd: SgdConfig = pydantic.Field(None, one_of="optimizer")
70 |   adagrad: AdagradConfig = pydantic.Field(None, one_of="optimizer")
71 | 
72 | 
73 | def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
74 |   if optimizer_config.adam is not None:
75 |     return optimizer_config.adam
76 |   elif optimizer_config.sgd is not None:
77 |     return optimizer_config.sgd
78 |   elif optimizer_config.adagrad is not None:
79 |     return optimizer_config.adagrad
80 |   else:
81 |     raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")
-------------------------------------------------------------------------------- /optimizers/optimizer.py: --------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 | import math
3 | import bisect
4 | 
5 | from tml.optimizers.config import (
6 |   LearningRate,
7 |   OptimizerConfig,
8 |   get_optimizer_algorithm_config,
9 | )
10 | 
11 | import torch
12 | from torch.optim import Optimizer
13 | from torch.optim.lr_scheduler import _LRScheduler
14 | from tml.ml_logging.torch_logging import logging
15 | 
16 | 
17 | def compute_lr(lr_config, step):
18 |   """Compute a learning rate."""
19 |   if lr_config.constant is not None:
20 |     return lr_config.constant
21 |   elif lr_config.piecewise_constant is not None:
22 |     return lr_config.piecewise_constant.learning_rate_values[
23 |       bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
24 |     ]
25 |   elif lr_config.linear_ramp_to_constant is not None:
26 |     slope = (
27 |       lr_config.linear_ramp_to_constant.learning_rate
28 |       / lr_config.linear_ramp_to_constant.num_ramp_steps
29 |     )
30 |     return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step)
31 |   elif lr_config.linear_ramp_to_cosine is not None:
32 |     cfg = lr_config.linear_ramp_to_cosine
33 |     if step < cfg.num_ramp_steps:
34 |       slope = cfg.learning_rate / cfg.num_ramp_steps
35 |       return slope * step
36 |     elif step <= cfg.final_num_steps:
37 |       return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
38 |         1.0
39 |         + math.cos(
40 |           math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
41 |         )
42 |       )
43 |     else:
44 |       return cfg.final_learning_rate
45 |   else:
46 |     raise ValueError(f"No option selected in lr_config, passed {lr_config}")
47 | 
48 | 
49 | class LRShim(_LRScheduler):
50 |   """Shim to get learning rates into a LRScheduler.
51 | 
52 |   This adheres to the torch.optim scheduler API and can be plugged anywhere that
53 |   e.g. exponential decay can be used.
54 |   """
55 | 
56 |   def __init__(
57 |     self,
58 |     optimizer,
59 |     lr_dict: Dict[str, LearningRate],
60 |     last_epoch=-1,
61 |     verbose=False,
62 |   ):
63 |     self.optimizer = optimizer
64 |     self.lr_dict = lr_dict
65 |     self.group_names = list(self.lr_dict.keys())
66 | 
67 |     # Handle both keyed optimizers (which expose ._optims) and plain torch optimizers.
68 |     num_param_groups = (
69 |       sum(1 for _, _optim in optimizer._optims for _ in _optim.param_groups)
70 |       if hasattr(optimizer, "_optims")
71 |       else len(optimizer.param_groups)
72 |     )
73 |     if num_param_groups != len(lr_dict):
74 |       raise ValueError(
75 |         f"Optimizer had {num_param_groups} param groups, but config had {len(lr_dict)}."
76 |       )
77 | 
78 |     super().__init__(optimizer, last_epoch, verbose)
79 | 
80 |   def get_lr(self):
81 |     if not self._get_lr_called_within_step:
82 |       logging.warn(
83 |         "To get the last learning rate computed by the scheduler, please use `get_last_lr()`."
84 |       )
85 |     return self._get_closed_form_lr()
86 | 
87 |   def _get_closed_form_lr(self):
88 |     return [compute_lr(lr_config, self.last_epoch) for lr_config in self.lr_dict.values()]
89 | 
90 | 
91 | def get_optimizer_class(optimizer_config: OptimizerConfig):
92 |   if optimizer_config.adam is not None:
93 |     return torch.optim.Adam
94 |   elif optimizer_config.sgd is not None:
95 |     return torch.optim.SGD
96 |   elif optimizer_config.adagrad is not None:
97 |     return torch.optim.Adagrad
98 |   raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")
99 | 
100 | 
101 | def build_optimizer(
102 |   model: torch.nn.Module, optimizer_config: OptimizerConfig
103 | ) -> Tuple[Optimizer, _LRScheduler]:
104 |   """Builds an optimizer and LR scheduler from an OptimizerConfig.
105 |   Note: use this when you want the same optimizer and learning rate schedule for all your parameters.
106 |   """
107 |   optimizer_class = get_optimizer_class(optimizer_config)
108 |   # Pass the hyperparameters of whichever algorithm is configured (adam/sgd/adagrad).
109 |   optimizer = optimizer_class(model.parameters(), **get_optimizer_algorithm_config(optimizer_config).dict())
110 |   # We're passing everything in as one group here.
111 |   scheduler = LRShim(optimizer, lr_dict={"ALL_PARAMS": optimizer_config.learning_rate})
112 |   return optimizer, scheduler
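# Illustrative sketch (not part of the library): compute_lr with a linear ramp
# reaches the target rate exactly at num_ramp_steps, e.g.
#
#   lr_config = LearningRate(
#     linear_ramp_to_constant=LinearRampToConstant(learning_rate=0.1, num_ramp_steps=100)
#   )
#   compute_lr(lr_config, step=50)   # -> 0.05 (halfway up the ramp)
#   compute_lr(lr_config, step=500)  # -> 0.1 (capped at the target rate)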
-------------------------------------------------------------------------------- /projects/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/twitter/the-algorithm-ml/b85210863f7a94efded0ef5c5ccf4ff42767876c/projects/__init__.py
-------------------------------------------------------------------------------- /projects/home/recap/README.md: --------------------------------------------------------------------------------
1 | # Heavy Ranker
2 | 
3 | ## Overview
4 | 
5 | The heavy ranker is a machine learning model used to rank Tweets for the "For You" timeline
6 | that have passed through the candidate retrieval stage. It is one of the final stages of the funnel,
7 | succeeded primarily by a set of filtering heuristics.
8 | 
9 | The model receives features describing a Tweet and the user that the Tweet is being recommended to
10 | (see [FEATURES.md](./FEATURES.md)). The model architecture is a parallel [MaskNet](https://arxiv.org/abs/2102.07619)
11 | which outputs a set of numbers between 0 and 1, with each output representing the probability that the user
12 | will engage with the Tweet in a particular way. The predicted engagement types are explained below:
13 | ```
14 | scored_tweets_model_weight_fav: The probability the user will favorite the Tweet.
15 | scored_tweets_model_weight_retweet: The probability the user will Retweet the Tweet.
16 | scored_tweets_model_weight_reply: The probability the user replies to the Tweet.
17 | scored_tweets_model_weight_good_profile_click: The probability the user opens the Tweet author profile and Likes or replies to a Tweet.
18 | scored_tweets_model_weight_video_playback50: The probability (for a video Tweet) that the user will watch at least half of the video.
19 | scored_tweets_model_weight_reply_engaged_by_author: The probability the user replies to the Tweet and this reply is engaged by the Tweet author.
20 | scored_tweets_model_weight_good_click: The probability the user will click into the conversation of this Tweet and reply or Like a Tweet.
21 | scored_tweets_model_weight_good_click_v2: The probability the user will click into the conversation of this Tweet and stay there for at least 2 minutes.
22 | scored_tweets_model_weight_negative_feedback_v2: The probability the user will react negatively (requesting "show less often" on the Tweet or author, block or mute the Tweet author).
23 | scored_tweets_model_weight_report: The probability the user will click Report Tweet.
24 | ```
25 | 
26 | The outputs of the model are combined into a final model score by taking a weighted sum of the predicted engagement probabilities.
27 | The weight of each engagement probability comes from a configuration file, read by the serving stack
28 | [here](https://github.com/twitter/the-algorithm/blob/main/home-mixer/server/src/main/scala/com/twitter/home_mixer/product/scored_tweets/param/ScoredTweetsParam.scala#L84). The exact weights in the file can be adjusted at any time, but the current weighting of probabilities
29 | (April 5, 2023) is as follows:
30 | ```
31 | scored_tweets_model_weight_fav: 0.5
32 | scored_tweets_model_weight_retweet: 1.0
33 | scored_tweets_model_weight_reply: 13.5
34 | scored_tweets_model_weight_good_profile_click: 12.0
35 | scored_tweets_model_weight_video_playback50: 0.005
36 | scored_tweets_model_weight_reply_engaged_by_author: 75.0
37 | scored_tweets_model_weight_good_click: 11.0
38 | scored_tweets_model_weight_good_click_v2: 10.0
39 | scored_tweets_model_weight_negative_feedback_v2: -74.0
40 | scored_tweets_model_weight_report: -369.0
41 | ```
42 | 
43 | Essentially, the formula is:
44 | ```
45 | score = sum_i { (weight of engagement i) * (probability of engagement i) }
46 | ```
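
For illustration only (this helper is not part of the repo), the weighted sum can be sketched as:
```python
# Hypothetical sketch: combine predicted engagement probabilities into one score.
WEIGHTS = {"fav": 0.5, "retweet": 1.0, "reply": 13.5}  # subset of the table above

def combined_score(probabilities):
  """Weighted sum over engagement probabilities, mirroring the formula above."""
  return sum(weight * probabilities[name] for name, weight in WEIGHTS.items())

# combined_score({"fav": 0.1, "retweet": 0.02, "reply": 0.004}) == 0.124
```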
47 | 
48 | Since each engagement has a different average probability, the weights were originally set so that,
49 | on average, each weighted engagement probability contributes a near-equal amount to the score.
50 | Since then, we have periodically adjusted the weights to optimize for platform metrics.
51 | 
52 | Some disclaimers:
53 | - Due to the need to make sure this runs independently from other parts of the Twitter codebase, there may be small differences from the production model.
54 | - We cannot release the real training data due to privacy restrictions. However, we have included a script to generate random data to ensure you can run the model training code.
55 | 
56 | ## Development
57 | After following the repo setup instructions, you can run the following script from a virtual environment to create a
58 | random training dataset in `$HOME/tmp/recap_local_random_data`:
59 | ```sh
60 | projects/home/recap/scripts/create_random_data.sh
61 | ```
62 | 
63 | You can then train the model using the following script.
64 | Checkpoints and logs will be written to `$HOME/tmp/runs/recap_local_debug`:
65 | ```sh
66 | projects/home/recap/scripts/run_local.sh
67 | ```
68 | 
69 | The model training can be configured in `projects/home/recap/config/local_prod.yaml`.
-------------------------------------------------------------------------------- /projects/home/recap/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/twitter/the-algorithm-ml/b85210863f7a94efded0ef5c5ccf4ff42767876c/projects/home/recap/__init__.py
-------------------------------------------------------------------------------- /projects/home/recap/config.py: --------------------------------------------------------------------------------
1 | from tml.core import config as config_mod
2 | import tml.projects.home.recap.data.config as data_config
3 | import tml.projects.home.recap.model.config as model_config
4 | import tml.projects.home.recap.optimizer.config as optimizer_config
5 | 
6 | from enum import Enum
7 | from typing import Dict, Optional
8 | import pydantic
9 | 
10 | 
11 | class TrainingConfig(config_mod.BaseConfig):
12 |   save_dir: str = "/tmp/model"
13 |   num_train_steps: pydantic.PositiveInt = 1000000
14 |   initial_checkpoint_dir: str = pydantic.Field(
15 |     None, description="Directory of initial checkpoints", at_most_one_of="initialization"
16 |   )
17 |   checkpoint_every_n: pydantic.PositiveInt = 1000
18 |   checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field(
19 |     None, description="Maximum number of checkpoints to keep. Defaults to keeping all."
20 |   )
21 |   train_log_every_n: pydantic.PositiveInt = 1000
22 |   num_eval_steps: int = pydantic.Field(
23 |     16384, description="Number of evaluation steps. If < 0 the entire dataset will be used."
24 |   )
25 |   eval_log_every_n: pydantic.PositiveInt = 5000
26 | 
27 |   eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60
28 | 
29 |   gradient_accumulation: int = pydantic.Field(
30 |     None, description="Number of replica steps to accumulate gradients."
31 |   )
32 | 
33 | 
34 | class RecapConfig(config_mod.BaseConfig):
35 |   training: TrainingConfig = pydantic.Field(TrainingConfig())
36 |   model: model_config.ModelConfig
37 |   train_data: data_config.RecapDataConfig
38 |   validation_data: Dict[str, data_config.RecapDataConfig]
39 |   optimizer: optimizer_config.RecapOptimizerConfig
40 | 
41 |   which_metrics: Optional[str] = pydantic.Field(None, description="which metrics to pick.")
42 | 
43 |   # DANGER DANGER! You might expect validators here to ensure that multi task learning setups are
44 |   # the same as the data. Unfortunately, that throws opaque errors when the model configuration is
45 |   # invalid. In our judgement, that is a more frequent and worse failure than tasks not matching
46 |   # the data.
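# Illustrative usage (mirroring the pattern in data/generate_random_data.py below;
# the YAML path here is an assumption): this config is normally loaded from YAML, e.g.
#
#   config = config_mod.load_config_from_yaml(RecapConfig, "projects/home/recap/config/local_prod.yaml")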
47 | 48 | 49 | class JobMode(str, Enum): 50 | """Job modes.""" 51 | 52 | TRAIN = "train" 53 | EVALUATE = "evaluate" 54 | INFERENCE = "inference" 55 | -------------------------------------------------------------------------------- /projects/home/recap/config/home_recap_2022/segdense.json: -------------------------------------------------------------------------------- 1 | { 2 | "schema": [ 3 | { 4 | "dtype": "int64_list", 5 | "feature_name": "home_recap_2022_discrete__segdense_vals", 6 | "length": 320 7 | }, 8 | { 9 | "dtype": "float_list", 10 | "feature_name": "home_recap_2022_cont__segdense_vals", 11 | "length": 6000 12 | }, 13 | { 14 | "dtype": "int64_list", 15 | "feature_name": "home_recap_2022_binary__segdense_vals", 16 | "length": 512 17 | }, 18 | { 19 | "dtype": "int64_list", 20 | "feature_name": "recap.engagement.is_tweet_detail_dwelled_15_sec", 21 | "length": 1 22 | }, 23 | { 24 | "dtype": "int64_list", 25 | "feature_name": "recap.engagement.is_profile_clicked_and_profile_engaged", 26 | "length": 1 27 | }, 28 | { 29 | "dtype": "int64_list", 30 | "feature_name": "recap.engagement.is_replied_reply_engaged_by_author", 31 | "length": 1 32 | }, 33 | { 34 | "dtype": "int64_list", 35 | "feature_name": "recap.engagement.is_video_playback_50", 36 | "length": 1 37 | }, 38 | { 39 | "dtype": "int64_list", 40 | "feature_name": "recap.engagement.is_report_tweet_clicked", 41 | "length": 1 42 | }, 43 | { 44 | "dtype": "int64_list", 45 | "feature_name": "recap.engagement.is_replied", 46 | "length": 1 47 | }, 48 | { 49 | "dtype": "int64_list", 50 | "feature_name": "meta.author_id", 51 | "length": 1 52 | }, 53 | { 54 | "dtype": "int64_list", 55 | "feature_name": "recap.engagement.is_negative_feedback_v2", 56 | "length": 1 57 | }, 58 | { 59 | "dtype": "int64_list", 60 | "feature_name": "recap.engagement.is_retweeted", 61 | "length": 1 62 | }, 63 | { 64 | "dtype": "int64_list", 65 | "feature_name": "recap.engagement.is_favorited", 66 | "length": 1 67 | }, 68 | { 69 | "dtype": "int64_list", 70 | "feature_name": "recap.engagement.is_good_clicked_convo_desc_favorited_or_replied", 71 | "length": 1 72 | }, 73 | { 74 | "dtype": "int64_list", 75 | "feature_name": "meta.tweet_id", 76 | "length": 1 77 | }, 78 | { 79 | "dtype": "int64_list", 80 | "feature_name": "recap.engagement.is_good_clicked_convo_desc_v2", 81 | "length": 1 82 | }, 83 | { 84 | "dtype": "int64_list", 85 | "feature_name": "meta.user_id", 86 | "length": 1 87 | }, 88 | { 89 | "dtype": "int64_list", 90 | "feature_name": "recap.engagement.is_bookmarked", 91 | "length": 1 92 | }, 93 | { 94 | "dtype": "int64_list", 95 | "feature_name": "recap.engagement.is_shared", 96 | "length": 1 97 | }, 98 | { 99 | "dtype": "float_list", 100 | "feature_name": "user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings", 101 | "length": 200 102 | }, 103 | { 104 | "dtype": "float_list", 105 | "feature_name": "original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings", 106 | "length": 200 107 | }, 108 | { 109 | "dtype": "float_list", 110 | "feature_name": "user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings", 111 | "length": 200 112 | } 113 | ] 114 | } -------------------------------------------------------------------------------- /projects/home/recap/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twitter/the-algorithm-ml/b85210863f7a94efded0ef5c5ccf4ff42767876c/projects/home/recap/data/__init__.py 
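A note on how the segdense schema above is consumed (illustrative; the actual parsing lives in `data/tfe_parsing.py`, which is not reproduced here): each `{dtype, feature_name, length}` row maps naturally onto a fixed-length TF parsing spec, e.g.

```python
import tensorflow as tf

# Hypothetical sketch of one segdense schema row -> a tf.io.FixedLenFeature.
row = {"dtype": "float_list", "feature_name": "home_recap_2022_cont__segdense_vals", "length": 6000}
DTYPES = {"float_list": tf.float32, "int64_list": tf.int64}
spec = {row["feature_name"]: tf.io.FixedLenFeature([row["length"]], DTYPES[row["dtype"]])}
```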
-------------------------------------------------------------------------------- /projects/home/recap/data/config.py: --------------------------------------------------------------------------------
1 | import typing
2 | from enum import Enum
3 | 
4 | 
5 | from tml.core import config as base_config
6 | 
7 | import pydantic
8 | 
9 | 
10 | class ExplicitDateInputs(base_config.BaseConfig):
11 |   """Arguments to select train/validation data using end_date and days of data."""
12 | 
13 |   data_root: str = pydantic.Field(..., description="Data path prefix.")
14 |   end_date: str = pydantic.Field(..., description="Data end date, inclusive.")
15 |   days: int = pydantic.Field(..., description="Number of days of data for dataset.")
16 |   num_missing_days_tol: int = pydantic.Field(
17 |     0, description="We tolerate <= num_missing_days_tol days of missing data."
18 |   )
19 | 
20 | 
21 | class ExplicitDatetimeInputs(base_config.BaseConfig):
22 |   """Arguments to select train/validation data using end_datetime and hours of data."""
23 | 
24 |   data_root: str = pydantic.Field(..., description="Data path prefix.")
25 |   end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.")
26 |   hours: int = pydantic.Field(..., description="Number of hours of data for dataset.")
27 |   num_missing_hours_tol: int = pydantic.Field(
28 |     0, description="We tolerate <= num_missing_hours_tol hours of missing data."
29 |   )
30 | 
31 | 
32 | class DdsCompressionOption(str, Enum):
33 |   """The only valid compression option is 'AUTO'."""
34 | 
35 |   AUTO = "AUTO"
36 | 
37 | 
38 | class DatasetConfig(base_config.BaseConfig):
39 |   inputs: str = pydantic.Field(
40 |     None, description="A glob for selecting data.", one_of="date_inputs_format"
41 |   )
42 |   explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
43 |     None, one_of="date_inputs_format"
44 |   )
45 |   explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format")
46 | 
47 |   global_batch_size: pydantic.PositiveInt
48 | 
49 |   num_files_to_keep: pydantic.PositiveInt = pydantic.Field(
50 |     None, description="Number of shards to keep."
51 |   )
52 |   repeat_files: bool = pydantic.Field(
53 |     True, description="DEPRECATED. Files are repeated regardless of this setting."
54 |   )
55 |   file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
56 | 
57 |   cache: bool = pydantic.Field(
58 |     False,
59 |     description="Cache dataset in memory. Careful to only use this when you"
60 |     " have enough memory to fit the entire dataset.",
61 |   )
62 | 
63 |   data_service_dispatcher: str = pydantic.Field(None)
64 |   ignore_data_errors: bool = pydantic.Field(
65 |     False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs."
66 |   )
67 |   dataset_service_compression: DdsCompressionOption = pydantic.Field(
68 |     None,
69 |     description="Compress the dataset for DDS worker -> training host. Disabled by default; the only valid option is 'AUTO'.",
70 |   )
71 | 
72 |   # tf.data.Dataset options
73 |   examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.")
74 |   map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
75 |     None, description="Number of parallel calls."
76 |   )
77 |   interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
78 |     None, description="Number of shards to interleave."
79 |   )
80 | 
81 | 
82 | class TruncateAndSlice(base_config.BaseConfig):
83 |   # Apply truncation and then slice.
84 | continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field( 85 | None, description="Experimental. Truncates continuous features to this amount for efficiency." 86 | ) 87 | binary_feature_truncation: pydantic.PositiveInt = pydantic.Field( 88 | None, description="Experimental. Truncates binary features to this amount for efficiency." 89 | ) 90 | 91 | continuous_feature_mask_path: str = pydantic.Field( 92 | None, description="Path of mask used to slice input continuous features." 93 | ) 94 | binary_feature_mask_path: str = pydantic.Field( 95 | None, description="Path of mask used to slice input binary features." 96 | ) 97 | 98 | 99 | class DataType(str, Enum): 100 | BFLOAT16 = "bfloat16" 101 | BOOL = "bool" 102 | 103 | FLOAT32 = "float32" 104 | FLOAT16 = "float16" 105 | 106 | UINT8 = "uint8" 107 | 108 | 109 | class DownCast(base_config.BaseConfig): 110 | # Apply down casting to selected features. 111 | features: typing.Dict[str, DataType] = pydantic.Field( 112 | None, description="Map features to down cast data types." 113 | ) 114 | 115 | 116 | class TaskData(base_config.BaseConfig): 117 | pos_downsampling_rate: float = pydantic.Field( 118 | 1.0, 119 | description="Downsampling rate of positives used to generate dataset.", 120 | ) 121 | neg_downsampling_rate: float = pydantic.Field( 122 | 1.0, 123 | description="Downsampling rate of negatives used to generate dataset.", 124 | ) 125 | 126 | 127 | class SegDenseSchema(base_config.BaseConfig): 128 | schema_path: str = pydantic.Field(..., description="Path to feature config json.") 129 | features: typing.List[str] = pydantic.Field( 130 | [], 131 | description="List of features (in addition to the renamed features) to read from schema path above.", 132 | ) 133 | renamed_features: typing.Dict[str, str] = pydantic.Field( 134 | {}, description="Dictionary of renamed features." 
135 |   )
136 |   mask_mantissa_features: typing.Dict[str, int] = pydantic.Field(
137 |     {},
138 |     description="(experimental) Number of mantissa bits to mask to simulate lower precision data.",
139 |   )
140 | 
141 | 
142 | class RectifyLabels(base_config.BaseConfig):
143 |   label_rectification_window_in_hours: float = pydantic.Field(
144 |     3.0, description="Overlap time in hours for which to flip labels."
145 |   )
146 |   served_timestamp_field: str = pydantic.Field(
147 |     ..., description="Input field corresponding to served time."
148 |   )
149 |   impressed_timestamp_field: str = pydantic.Field(
150 |     ..., description="Input field corresponding to impressed time."
151 |   )
152 |   label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field(
153 |     ..., description="Label to the input field corresponding to engagement time."
154 |   )
155 | 
156 | 
157 | class ExtractFeaturesRow(base_config.BaseConfig):
158 |   name: str = pydantic.Field(
159 |     ...,
160 |     description="Name of the new field to be created.",
161 |   )
162 |   source_tensor: str = pydantic.Field(
163 |     ...,
164 |     description="Name of the dense tensor to look for the feature in.",
165 |   )
166 |   index: int = pydantic.Field(
167 |     ...,
168 |     description="Index of the feature in the dense tensor.",
169 |   )
170 | 
171 | 
172 | class ExtractFeatures(base_config.BaseConfig):
173 |   extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field(
174 |     [],
175 |     description="List of features to be extracted, with their name, source tensor and index.",
176 |   )
177 | 
178 | 
179 | class DownsampleNegatives(base_config.BaseConfig):
180 |   batch_multiplier: int = pydantic.Field(
181 |     None,
182 |     description="Batch multiplier.",
183 |   )
184 |   engagements_list: typing.List[str] = pydantic.Field(
185 |     [],
186 |     description="Engagements with kept positives.",
187 |   )
188 |   num_engagements: int = pydantic.Field(
189 |     ...,
190 |     description="Number of engagements used in the model, including ones excluded in engagements_list.",
191 |   )
192 | 
193 | 
194 | class Preprocess(base_config.BaseConfig):
195 |   truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
196 |   downcast: DownCast = pydantic.Field(None, description="Down casting applied to selected features.")
197 |   rectify_labels: RectifyLabels = pydantic.Field(
198 |     None, description="Rectify labels for a given overlap window."
199 |   )
200 |   extract_features: ExtractFeatures = pydantic.Field(
201 |     None, description="Extract features from dense tensors."
202 |   )
203 |   downsample_negatives: DownsampleNegatives = pydantic.Field(
204 |     None, description="Downsample negatives."
205 |   )
206 | 
207 | 
208 | class Sampler(base_config.BaseConfig):
209 |   """Assumes function is defined in data/samplers.py.
210 | 
211 |   Only use this for quick experimentation.
212 |   If samplers are useful, we should sample from upstream data generation.
213 | 
214 |   DEPRECATED, DO NOT USE.
215 |   """
216 | 
217 |   name: str
218 |   kwargs: typing.Dict
219 | 
220 | 
221 | class RecapDataConfig(DatasetConfig):
222 |   seg_dense_schema: SegDenseSchema
223 | 
224 |   tasks: typing.Dict[str, TaskData] = pydantic.Field(
225 |     description="Description of individual tasks in this dataset."
226 |   )
227 |   evaluation_tasks: typing.List[str] = pydantic.Field(
228 |     [], description="If specified, lists the tasks we're generating metrics for."
229 |   )
230 | 
231 |   preprocess: Preprocess = pydantic.Field(
232 |     None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference."
233 | ) 234 | 235 | sampler: Sampler = pydantic.Field( 236 | None, 237 | description="""DEPRECATED, DO NOT USE. Sampling function for offline experiments.""", 238 | ) 239 | 240 | @pydantic.root_validator() 241 | def _validate_evaluation_tasks(cls, values): 242 | if values.get("evaluation_tasks") is not None: 243 | for task in values["evaluation_tasks"]: 244 | if task not in values["tasks"]: 245 | raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}") 246 | return values 247 | -------------------------------------------------------------------------------- /projects/home/recap/data/generate_random_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from absl import app, flags, logging 4 | import tensorflow as tf 5 | from typing import Dict 6 | 7 | from tml.projects.home.recap.data import tfe_parsing 8 | from tml.core import config as tml_config_mod 9 | import tml.projects.home.recap.config as recap_config_mod 10 | 11 | flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.") 12 | flags.DEFINE_integer("n_examples", 100, "Number of examples to generate.") 13 | 14 | FLAGS = flags.FLAGS 15 | 16 | 17 | def _generate_random_example( 18 | tf_example_schema: Dict[str, tf.io.FixedLenFeature] 19 | ) -> Dict[str, tf.Tensor]: 20 | example = {} 21 | for feature_name, feature_spec in tf_example_schema.items(): 22 | dtype = feature_spec.dtype 23 | if (dtype == tf.int64) or (dtype == tf.int32): 24 | x = tf.experimental.numpy.random.randint(0, high=10, size=feature_spec.shape, dtype=dtype) 25 | elif (dtype == tf.float32) or (dtype == tf.float64): 26 | x = tf.random.uniform(shape=[feature_spec.shape], dtype=dtype) 27 | else: 28 | raise NotImplementedError(f"Unknown type {dtype}") 29 | 30 | example[feature_name] = x 31 | 32 | return example 33 | 34 | 35 | def _float_feature(value): 36 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 37 | 38 | 39 | def _int64_feature(value): 40 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 41 | 42 | 43 | def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes: 44 | feature = {} 45 | serializers = {tf.float32: _float_feature, tf.int64: _int64_feature} 46 | for feature_name, tensor in x.items(): 47 | feature[feature_name] = serializers[tensor.dtype](tensor) 48 | 49 | example_proto = tf.train.Example(features=tf.train.Features(feature=feature)) 50 | return example_proto.SerializeToString() 51 | 52 | 53 | def generate_data(data_path: str, config: recap_config_mod.RecapConfig): 54 | with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f: 55 | seg_dense_schema = json.load(f)["schema"] 56 | 57 | tf_example_schema = tfe_parsing.create_tf_example_schema( 58 | config.train_data, 59 | seg_dense_schema, 60 | ) 61 | 62 | record_filename = os.path.join(data_path, "random.tfrecord.gz") 63 | 64 | with tf.io.TFRecordWriter(record_filename, "GZIP") as writer: 65 | for _ in range(FLAGS.n_examples):  # Write --n_examples random examples. 66 | random_example = _generate_random_example(tf_example_schema) 67 | writer.write(_serialize_example(random_example)) 68 | 69 | 70 | def _generate_data_main(unused_argv): 71 | config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path) 72 | 73 | # Find the path where to put the data 74 | data_path = os.path.dirname(config.train_data.inputs) 75 | logging.info("Putting random data in %s", data_path) 76 | 77 | generate_data(data_path, config) 78 | 79 | 80 | if __name__ == "__main__": 81 | app.run(_generate_data_main) 82 |
-------------------------------------------------------------------------------- /projects/home/recap/data/preprocessors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocessors applied on DDS workers in order to modify the dataset on the fly. 3 | Some of these preprocessors are also applied to the model at serving time. 4 | """ 5 | from tml.projects.home.recap import config as config_mod 6 | from absl import logging 7 | import tensorflow as tf 8 | import numpy as np 9 | 10 | 11 | class TruncateAndSlice(tf.keras.Model): 12 | """Class for truncating and slicing.""" 13 | 14 | def __init__(self, truncate_and_slice_config): 15 | super().__init__() 16 | self._truncate_and_slice_config = truncate_and_slice_config 17 | 18 | if self._truncate_and_slice_config.continuous_feature_mask_path: 19 | with tf.io.gfile.GFile( 20 | self._truncate_and_slice_config.continuous_feature_mask_path, "rb" 21 | ) as f: 22 | self._continuous_mask = np.load(f).nonzero()[0] 23 | logging.info(f"Slicing {len(self._continuous_mask)} continuous features.") 24 | else: 25 | self._continuous_mask = None 26 | 27 | if self._truncate_and_slice_config.binary_feature_mask_path: 28 | with tf.io.gfile.GFile(self._truncate_and_slice_config.binary_feature_mask_path, "rb") as f: 29 | self._binary_mask = np.load(f).nonzero()[0] 30 | logging.info(f"Slicing {len(self._binary_mask)} binary features.") 31 | else: 32 | self._binary_mask = None 33 | 34 | def call(self, inputs, training=None, mask=None): 35 | outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs)) 36 | if self._truncate_and_slice_config.continuous_feature_truncation: 37 | logging.info("Truncating continuous") 38 | outputs["continuous"] = outputs["continuous"][ 39 | :, : self._truncate_and_slice_config.continuous_feature_truncation 40 | ] 41 | if self._truncate_and_slice_config.binary_feature_truncation: 42 | logging.info("Truncating binary") 43 | outputs["binary"] = outputs["binary"][ 44 | :, : self._truncate_and_slice_config.binary_feature_truncation 45 | ] 46 | if self._continuous_mask is not None: 47 | outputs["continuous"] = tf.gather(outputs["continuous"], self._continuous_mask, axis=1) 48 | if self._binary_mask is not None: 49 | outputs["binary"] = tf.gather(outputs["binary"], self._binary_mask, axis=1) 50 | return outputs 51 | 52 | 53 | class DownCast(tf.keras.Model): 54 | """Class for down-casting the dataset before serialization and transfer to the training host. 55 | Depending on the data type and the actual data range, the down casting may or may not be lossless. 56 | It is strongly recommended to compare the metrics before and after down casting. 57 | """ 58 | 59 | def __init__(self, downcast_config): 60 | super().__init__() 61 | self.config = downcast_config 62 | self._type_map = { 63 | "bfloat16": tf.bfloat16, 64 | "bool": tf.bool, 65 | } 66 | 67 | def call(self, inputs, training=None, mask=None): 68 | outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs)) 69 | for feature, type_str in self.config.features.items(): 70 | assert type_str in self._type_map 71 | if type_str == "bfloat16": 72 | logging.warning( 73 | "Although bfloat16 and float32 have the same number of exponent bits, this down casting is not 100% lossless. Please double check metrics."
74 | ) 75 | down_cast_data_type = self._type_map[type_str] 76 | outputs[feature] = tf.cast(outputs[feature], dtype=down_cast_data_type) 77 | return outputs 78 | 79 | 80 | class RectifyLabels(tf.keras.Model): 81 | """Class for rectifying labels""" 82 | 83 | def __init__(self, rectify_label_config): 84 | super().__init__() 85 | self._config = rectify_label_config 86 | self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000) 87 | 88 | def call(self, inputs, training=None, mask=None): 89 | served_ts_field = self._config.served_timestamp_field 90 | impressed_ts_field = self._config.impressed_timestamp_field 91 | 92 | for label, engaged_ts_field in self._config.label_to_engaged_timestamp_field.items(): 93 | impressed = inputs[impressed_ts_field] 94 | served = inputs[served_ts_field] 95 | engaged = inputs[engaged_ts_field] 96 | 97 | keep = tf.math.logical_and(inputs[label] > 0, impressed - served < self._window) 98 | keep = tf.math.logical_and(keep, engaged - served < self._window) 99 | inputs[label] = tf.where(keep, inputs[label], tf.zeros_like(inputs[label])) 100 | 101 | return inputs 102 | 103 | 104 | class ExtractFeatures(tf.keras.Model): 105 | """Class for extracting individual features from dense tensors by their index.""" 106 | 107 | def __init__(self, extract_features_config): 108 | super().__init__() 109 | self._config = extract_features_config 110 | 111 | def call(self, inputs, training=None, mask=None): 112 | 113 | for row in self._config.extract_feature_table: 114 | inputs[row.name] = inputs[row.source_tensor][:, row.index] 115 | 116 | return inputs 117 | 118 | 119 | class DownsampleNegatives(tf.keras.Model): 120 | """Class for down-sampling/dropping negatives and updating the weights. 121 | 122 | If inputs['fav'] = [1, 0, 0, 0] and inputs['weights'] = [1.0, 1.0, 1.0, 1.0] 123 | inputs are transformed to inputs['fav'] = [1, 0] and inputs['weights'] = [1.0, 3.0] 124 | when batch_multiplier=2 and engagements_list=['fav'] 125 | 126 | It supports multiple engagements (union/logical_or is used to aggregate engagements), so we don't 127 | drop positives for any engagement. 
128 | """ 129 | 130 | def __init__(self, downsample_negatives_config): 131 | super().__init__() 132 | self.config = downsample_negatives_config 133 | 134 | def call(self, inputs, training=None, mask=None): 135 | labels = self.config.engagements_list 136 | # union of engagements 137 | mask = tf.squeeze(tf.reduce_any(tf.stack([inputs[label] == 1 for label in labels], 1), 1)) 138 | n_positives = tf.reduce_sum(tf.cast(mask, tf.int32)) 139 | batch_size = tf.cast(tf.shape(inputs[labels[0]])[0] / self.config.batch_multiplier, tf.int32) 140 | negative_weights = tf.math.divide_no_nan( 141 | tf.cast(self.config.batch_multiplier * batch_size - n_positives, tf.float32), 142 | tf.cast(batch_size - n_positives, tf.float32), 143 | ) 144 | new_weights = tf.cast(mask, tf.float32) + (1 - tf.cast(mask, tf.float32)) * negative_weights 145 | 146 | def _split_by_label_concatenate_and_truncate(input_tensor): 147 | # takes positive examples and concatenate with negative examples and truncate 148 | # DANGER: if n_positives > batch_size down-sampling is incorrect (do not use pb_50) 149 | return tf.concat( 150 | [ 151 | input_tensor[mask], 152 | input_tensor[tf.math.logical_not(mask)], 153 | ], 154 | 0, 155 | )[:batch_size] 156 | 157 | if "weights" not in inputs: 158 | # add placeholder so logic below applies even if weights aren't present in inputs 159 | inputs["weights"] = tf.ones([tf.shape(inputs[labels[0]])[0], self.config.num_engagements]) 160 | 161 | for tensor in inputs: 162 | if tensor == "weights": 163 | inputs[tensor] = inputs[tensor] * tf.reshape(new_weights, [-1, 1]) 164 | 165 | inputs[tensor] = _split_by_label_concatenate_and_truncate(inputs[tensor]) 166 | 167 | return inputs 168 | 169 | 170 | def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN): 171 | """Builds a preprocess model to apply all preprocessing stages.""" 172 | if mode == config_mod.JobMode.INFERENCE: 173 | logging.info("Not building preprocessors for dataloading since we are in Inference mode.") 174 | return None 175 | 176 | preprocess_models = [] 177 | if preprocess_config.downsample_negatives: 178 | preprocess_models.append(DownsampleNegatives(preprocess_config.downsample_negatives)) 179 | if preprocess_config.truncate_and_slice: 180 | preprocess_models.append(TruncateAndSlice(preprocess_config.truncate_and_slice)) 181 | if preprocess_config.downcast: 182 | preprocess_models.append(DownCast(preprocess_config.downcast)) 183 | if preprocess_config.rectify_labels: 184 | preprocess_models.append(RectifyLabels(preprocess_config.rectify_labels)) 185 | if preprocess_config.extract_features: 186 | preprocess_models.append(ExtractFeatures(preprocess_config.extract_features)) 187 | 188 | if len(preprocess_models) == 0: 189 | raise ValueError("No known preprocessor.") 190 | 191 | class PreprocessModel(tf.keras.Model): 192 | def __init__(self, preprocess_models): 193 | super().__init__() 194 | self.preprocess_models = preprocess_models 195 | 196 | def call(self, inputs, training=None, mask=None): 197 | outputs = inputs 198 | for model in self.preprocess_models: 199 | outputs = model(outputs, training, mask) 200 | return outputs 201 | 202 | if len(preprocess_models) > 1: 203 | logging.warning( 204 | "With multiple preprocessing models, we apply these models in a predefined order. Future works may introduce customized models and orders." 
205 | ) 206 | return PreprocessModel(preprocess_models) 207 | -------------------------------------------------------------------------------- /projects/home/recap/data/tfe_parsing.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import json 3 | 4 | from tml.projects.home.recap.data import config as recap_data_config 5 | 6 | from absl import logging 7 | import tensorflow as tf 8 | 9 | 10 | DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""} 11 | DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string} 12 | 13 | 14 | def create_tf_example_schema( 15 | data_config: recap_data_config.RecapDataConfig, 16 | segdense_schema, 17 | ): 18 | """Generate schema for deserializing tf.Example. 19 | 20 | Args: 21 | data_config: Data config whose seg_dense_schema and tasks select the features and labels to parse. 22 | segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length). 23 | 24 | Returns: 25 | A dictionary schema suitable for deserializing tf.Example. 26 | """ 27 | segdense_config = data_config.seg_dense_schema 28 | labels = list(data_config.tasks.keys()) 29 | used_features = ( 30 | segdense_config.features + list(segdense_config.renamed_features.values()) + labels 31 | ) 32 | logging.info(used_features) 33 | 34 | tfe_schema = {} 35 | for entry in segdense_schema: 36 | feature_name = entry["feature_name"] 37 | 38 | if feature_name in used_features: 39 | length = entry["length"] 40 | dtype = entry["dtype"] 41 | 42 | if feature_name in labels: 43 | logging.info(f"Label: feature name is {feature_name} type is {dtype}") 44 | tfe_schema[feature_name] = tf.io.FixedLenFeature( 45 | length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype] 46 | ) 47 | elif length == -1: 48 | tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype]) 49 | else: 50 | tfe_schema[feature_name] = tf.io.FixedLenFeature( 51 | length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length 52 | ) 53 | for feature_name in used_features: 54 | if feature_name not in tfe_schema: 55 | raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.") 56 | return tfe_schema 57 | 58 | 59 | @functools.lru_cache(1) 60 | def make_mantissa_mask(mask_length: int) -> tf.Tensor: 61 | """For experimenting with emulating bfloat16 or less precise types.""" 62 | return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32) 63 | 64 | 65 | def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor: 66 | """For experimenting with emulating bfloat16 or less precise types.""" 67 | mask: tf.Tensor = make_mantissa_mask(mask_length) 68 | return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype) 69 | 70 | 71 | def parse_tf_example( 72 | serialized_example, 73 | tfe_schema, 74 | seg_dense_schema_config, 75 | ): 76 | """Parse serialized tf.Example into dict of tensors. 77 | 78 | Args: 79 | serialized_example: Serialized tf.Example to be parsed. 80 | tfe_schema: Dictionary schema suitable for deserializing tf.Example. 81 | 82 | Returns: 83 | Dictionary of tensors to be used as model input. 84 | """ 85 | inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema) 86 | 87 | for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items(): 88 | inputs[new_feature_name] = inputs.pop(old_feature_name) 89 | 90 | # This should not actually be used except for experimentation with low precision floats.
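  # Illustrative sketch (not part of the original file; values chosen only for
  # illustration): with mask_length=16 the mask is 0xFFFF0000, which zeroes the
  # 16 low-order mantissa bits of a float32 and leaves 7, roughly emulating
  # bfloat16 precision:
  #   x = tf.constant([1.2345678], tf.float32)
  #   mask_mantissa(x, 16)  # ~= [1.234375]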
91 | if "mask_mantissa_features" in seg_dense_schema_config: 92 | for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items(): 93 | inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length) 94 | 95 | # DANGER DANGER: This default seems really scary, and it's only here because it has to be visible 96 | # at TF level. 97 | # We should not return empty tensors if we dont use embeddings. 98 | # Otherwise, it breaks numpy->pt conversion 99 | renamed_keys = list(seg_dense_schema_config.renamed_features.keys()) 100 | for renamed_key in renamed_keys: 101 | if "embedding" in renamed_key and (renamed_key not in inputs): 102 | inputs[renamed_key] = tf.zeros([], tf.float32) 103 | 104 | logging.info(f"parsed example and inputs are {inputs}") 105 | return inputs 106 | 107 | 108 | def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig): 109 | """Placeholder for seg dense. 110 | 111 | In the future, when we use more seg dense variations, we can change this. 112 | """ 113 | with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f: 114 | seg_dense_schema = json.load(f)["schema"] 115 | 116 | tf_example_schema = create_tf_example_schema( 117 | data_config, 118 | seg_dense_schema, 119 | ) 120 | 121 | logging.info("***** TF Example Schema *****") 122 | logging.info(tf_example_schema) 123 | 124 | parse = functools.partial( 125 | parse_tf_example, 126 | tfe_schema=tf_example_schema, 127 | seg_dense_schema_config=data_config.seg_dense_schema, 128 | ) 129 | return parse 130 | -------------------------------------------------------------------------------- /projects/home/recap/data/util.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Tuple, Union 2 | import torch 3 | import torchrec 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | 8 | def keyed_tensor_from_tensors_dict( 9 | tensor_map: Mapping[str, torch.Tensor] 10 | ) -> "torchrec.KeyedTensor": 11 | """ 12 | Convert a dictionary of torch tensor to torchrec keyed tensor 13 | Args: 14 | tensor_map: 15 | 16 | Returns: 17 | 18 | """ 19 | keys = list(tensor_map.keys()) 20 | # We expect batch size to be first dim. However, if we get a shape [Batch_size], 21 | # KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is 22 | # [Batch_size x 1]. 23 | values = [ 24 | tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1) 25 | for key in keys 26 | ] 27 | return torchrec.KeyedTensor.from_tensor_list(keys, values) 28 | 29 | 30 | def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 31 | if tensor.is_sparse: 32 | x = tensor.coalesce() # Ensure that the indices are ordered. 33 | lengths = torch.bincount(x.indices()[0]) 34 | values = x.values() 35 | else: 36 | values = tensor 37 | lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device) 38 | return values, lengths 39 | 40 | 41 | def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor": 42 | """ 43 | Convert a torch tensor to torchrec jagged tensor. 44 | Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors. 45 | For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the 46 | dense_shape of the sparse tensor can be arbitrary. 47 | Args: 48 | tensor: a torch (sparse) tensor. 
49 | Returns: a torchrec JaggedTensor. 50 | """ 51 | values, lengths = _compute_jagged_tensor_from_tensor(tensor) 52 | return torchrec.JaggedTensor(values=values, lengths=lengths) 53 | 54 | 55 | def keyed_jagged_tensor_from_tensors_dict( 56 | tensor_map: Mapping[str, torch.Tensor] 57 | ) -> "torchrec.KeyedJaggedTensor": 58 | """ 59 | Convert a dictionary of (sparse) torch tensors to a torchrec KeyedJaggedTensor. 60 | Note: Currently only supports shapes of [Batch_size] or [Batch_size x 1] for dense tensors. 61 | For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x 1]; the 62 | dense_shape of the sparse tensor can be arbitrary. 63 | Args: 64 | tensor_map: mapping from feature name to (sparse) torch tensor. 65 | 66 | Returns: 67 | A torchrec KeyedJaggedTensor with one key per entry in tensor_map. 68 | """ 69 | 70 | if not tensor_map: 71 | return torchrec.KeyedJaggedTensor( 72 | keys=[], 73 | values=torch.zeros(0, dtype=torch.int), 74 | lengths=torch.zeros(0, dtype=torch.int), 75 | ) 76 | values = [] 77 | lengths = [] 78 | for tensor in tensor_map.values(): 79 | tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor) 80 | values.append(torch.squeeze(tensor_val)) 81 | lengths.append(tensor_len) 82 | 83 | values = torch.cat(values, dim=0) 84 | lengths = torch.cat(lengths, dim=0) 85 | 86 | return torchrec.KeyedJaggedTensor( 87 | keys=list(tensor_map.keys()), 88 | values=values, 89 | lengths=lengths, 90 | ) 91 | 92 | 93 | def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray: 94 | return tf_tensor._numpy()  # noqa 95 | 96 | 97 | def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor: 98 | tensor = _tf_to_numpy(tensor) 99 | # torch.from_numpy does not accept numpy's bfloat16 extension dtype; up cast to float32, which keeps the same number of exponent bits. 100 | if tensor.dtype.name == "bfloat16": 101 | tensor = tensor.astype(np.float32) 102 | 103 | tensor = torch.from_numpy(tensor) 104 | if pin_memory: 105 | tensor = tensor.pin_memory() 106 | return tensor 107 | 108 | 109 | def sparse_or_dense_tf_to_torch( 110 | tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool 111 | ) -> torch.Tensor: 112 | if isinstance(tensor, tf.SparseTensor): 113 | tensor = torch.sparse_coo_tensor( 114 | _dense_tf_to_torch(tensor.indices, pin_memory).t(), 115 | _dense_tf_to_torch(tensor.values, pin_memory), 116 | torch.Size(_tf_to_numpy(tensor.dense_shape)), 117 | ) 118 | else: 119 | tensor = _dense_tf_to_torch(tensor, pin_memory) 120 | return tensor 121 | -------------------------------------------------------------------------------- /projects/home/recap/embedding/config.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | import tml.core.config as base_config 3 | from tml.optimizers import config as optimizer_config 4 | 5 | import pydantic 6 | 7 | 8 | class EmbeddingSnapshot(base_config.BaseConfig): 9 | """Configuration for Embedding snapshot""" 10 | 11 | emb_name: str = pydantic.Field( 12 | ..., description="Name of the embedding table from the loaded snapshot" 13 | ) 14 | embedding_snapshot_uri: str = pydantic.Field( 15 | ..., description="Path to torchsnapshot of the embedding" 16 | ) 17 | 18 | 19 | # https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_configs.EmbeddingBagConfig 20 | class EmbeddingBagConfig(base_config.BaseConfig): 21 | """Configuration for EmbeddingBag.""" 22 | 23 | name: str = pydantic.Field(..., description="name of embedding bag") 24 | num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary") 25 | embedding_dim: int = pydantic.Field(..., description="size of each embedding
vector") 26 | pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties") 27 | vocab: str = pydantic.Field( 28 | None, description="Directory to parquet files of mapping from entity ID to table index." 29 | ) 30 | 31 | 32 | class EmbeddingOptimizerConfig(base_config.BaseConfig): 33 | learning_rate: optimizer_config.LearningRate = pydantic.Field( 34 | None, description="learning rate scheduler for the EBC" 35 | ) 36 | init_learning_rate: float = pydantic.Field(description="initial learning rate for the EBC") 37 | # NB: Only sgd is supported right now and implicitly. 38 | # FBGemm only supports simple exact_sgd which only takes LR as an argument. 39 | 40 | 41 | class LargeEmbeddingsConfig(base_config.BaseConfig): 42 | """Configuration for EmbeddingBagCollection. 43 | 44 | The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection. 45 | """ 46 | 47 | tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables") 48 | optimizer: EmbeddingOptimizerConfig 49 | tables_to_log: List[str] = pydantic.Field( 50 | None, description="list of embedding table names that we want to log during training" 51 | ) 52 | 53 | 54 | class StratifierConfig(base_config.BaseConfig): 55 | name: str 56 | index: int 57 | value: int 58 | 59 | 60 | class SmallEmbeddingBagConfig(base_config.BaseConfig): 61 | """Configuration for SmallEmbeddingBag.""" 62 | 63 | name: str = pydantic.Field(..., description="name of embedding bag") 64 | num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary") 65 | embedding_dim: int = pydantic.Field(..., description="size of each embedding vector") 66 | index: int = pydantic.Field(..., description="index in the discrete tensor to look for") 67 | 68 | 69 | class SmallEmbeddingBagConfig(base_config.BaseConfig): 70 | """Configuration for SmallEmbeddingBag.""" 71 | 72 | name: str = pydantic.Field(..., description="name of embedding bag") 73 | num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary") 74 | embedding_dim: int = pydantic.Field(..., description="size of each embedding vector") 75 | index: int = pydantic.Field(..., description="index in the discrete tensor to look for") 76 | 77 | 78 | class SmallEmbeddingsConfig(base_config.BaseConfig): 79 | """Configuration for SmallEmbeddingConfig. 80 | 81 | Here we can use discrete features that already are present in our TFRecords generated using 82 | segdense conversion as "home_recap_2022_discrete__segdense_vals" which are available in 83 | the model as "discrete_features", and embed a user-defined set of them with configurable 84 | dimensions and vocabulary sizes. 85 | 86 | Compared with LargeEmbedding, this config is for small embedding tables that can fit inside 87 | the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at 88 | serving time due to size (>>1 GB). 
89 | 90 | This small embeddings table uses the same optimizer as the rest of the model.""" 91 | 92 | tables: List[SmallEmbeddingBagConfig] = pydantic.Field( 93 | ..., description="list of embedding tables" 94 | ) 95 | -------------------------------------------------------------------------------- /projects/home/recap/main.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | from typing import Callable, List, Optional, Tuple 4 | import tensorflow as tf 5 | 6 | import tml.common.checkpointing.snapshot as snapshot_lib 7 | from tml.common.device import setup_and_get_device 8 | from tml.core import config as tml_config_mod 9 | import tml.core.custom_training_loop as ctl 10 | from tml.core import debug_training_loop 11 | from tml.core import losses 12 | from tml.core.loss_type import LossType 13 | from tml.model import maybe_shard_model 14 | 15 | 16 | import tml.projects.home.recap.data.dataset as ds 17 | import tml.projects.home.recap.config as recap_config_mod 18 | import tml.projects.home.recap.optimizer as optimizer_mod 19 | 20 | 21 | # from tml.projects.home.recap import feature 22 | import tml.projects.home.recap.model as model_mod 23 | import torchmetrics as tm 24 | import torch 25 | import torch.distributed as dist 26 | from torchrec.distributed.model_parallel import DistributedModelParallel 27 | 28 | from absl import app, flags, logging 29 | 30 | flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.") 31 | flags.DEFINE_bool("debug_loop", False, "Run with debug loop (slow)") 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | 36 | def run(unused_argv: str, data_service_dispatcher: Optional[str] = None): 37 | print("#" * 100) 38 | 39 | config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path) 40 | logging.info("Config: %s", config.pretty_print()) 41 | 42 | device = setup_and_get_device() 43 | 44 | # Always enable tensorfloat on supported devices. 45 | torch.backends.cuda.matmul.allow_tf32 = True 46 | torch.backends.cudnn.allow_tf32 = True 47 | 48 | loss_fn = losses.build_multi_task_loss( 49 | loss_type=LossType.BCE_WITH_LOGITS, 50 | tasks=list(config.model.tasks.keys()), 51 | pos_weights=[task.pos_weight for task in config.model.tasks.values()], 52 | ) 53 | 54 | # Since the prod model doesn't use large embeddings, for now we won't support them. 
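  # For illustration (hypothetical YAML fragment, not from this repo): a config that sets
  #   model:
  #     large_embeddings:
  #       tables: [...]
  # would fail the assertion below; leave model.large_embeddings unset when training
  # the prod model.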
55 | assert config.model.large_embeddings is None 56 | 57 | train_dataset = ds.RecapDataset( 58 | data_config=config.train_data, 59 | dataset_service=data_service_dispatcher, 60 | mode=recap_config_mod.JobMode.TRAIN, 61 | compression=config.train_data.dataset_service_compression, 62 | vocab_mapper=None, 63 | repeat=True, 64 | ) 65 | 66 | train_iterator = iter(train_dataset.to_dataloader()) 67 | 68 | torch_element_spec = train_dataset.torch_element_spec 69 | 70 | model = model_mod.create_ranking_model( 71 | data_spec=torch_element_spec[0], 72 | config=config, 73 | loss_fn=loss_fn, 74 | device=device, 75 | ) 76 | 77 | optimizer, scheduler = optimizer_mod.build_optimizer(model, config.optimizer, None) 78 | 79 | model = maybe_shard_model(model, device) 80 | 81 | datetime_str = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M") 82 | print(f"{datetime_str}\n", end="") 83 | 84 | if FLAGS.debug_loop: 85 | logging.warning("Running debug mode, slow!") 86 | train_mod = debug_training_loop 87 | else: 88 | train_mod = ctl 89 | 90 | train_mod.train( 91 | model=model, 92 | optimizer=optimizer, 93 | device=device, 94 | save_dir=config.training.save_dir, 95 | logging_interval=config.training.train_log_every_n, 96 | train_steps=config.training.num_train_steps, 97 | checkpoint_frequency=config.training.checkpoint_every_n, 98 | dataset=train_iterator, 99 | worker_batch_size=config.train_data.global_batch_size, 100 | enable_amp=False, 101 | initial_checkpoint_dir=config.training.initial_checkpoint_dir, 102 | gradient_accumulation=config.training.gradient_accumulation, 103 | scheduler=scheduler, 104 | ) 105 | 106 | 107 | if __name__ == "__main__": 108 | app.run(run) 109 | -------------------------------------------------------------------------------- /projects/home/recap/model/__init__.py: -------------------------------------------------------------------------------- 1 | from tml.projects.home.recap.model.entrypoint import ( 2 | create_ranking_model, 3 | sanitize, 4 | unsanitize, 5 | MultiTaskRankingModel, 6 | ) 7 | from tml.projects.home.recap.model.model_and_loss import ModelAndLoss 8 | -------------------------------------------------------------------------------- /projects/home/recap/model/feature_transform.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Sequence, Union 2 | 3 | from tml.projects.home.recap.model.config import ( 4 | BatchNormConfig, 5 | DoubleNormLogConfig, 6 | FeaturizationConfig, 7 | LayerNormConfig, 8 | ) 9 | 10 | import torch 11 | 12 | 13 | def log_transform(x: torch.Tensor) -> torch.Tensor: 14 | """Safe log transform that works across negative, zero, and positive floats.""" 15 | return torch.sign(x) * torch.log1p(torch.abs(x)) 16 | 17 | 18 | class BatchNorm(torch.nn.Module): 19 | def __init__(self, num_features: int, config: BatchNormConfig): 20 | super().__init__() 21 | self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum) 22 | 23 | def forward(self, x: torch.Tensor) -> torch.Tensor: 24 | return self.layer(x) 25 | 26 | 27 | class LayerNorm(torch.nn.Module): 28 | def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig): 29 | super().__init__() 30 | if config.axis != -1: 31 | raise NotImplementedError 32 | if config.center != config.scale: 33 | raise ValueError( 34 | f"Center and scale must match in torch, received {config.center}, {config.scale}" 35 | ) 36 | self.layer = torch.nn.LayerNorm( 37 | normalized_shape, eps=config.epsilon,
elementwise_affine=config.center 38 | ) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | return self.layer(x) 42 | 43 | 44 | class Log1pAbs(torch.nn.Module): 45 | def __init__(self): 46 | super().__init__() 47 | 48 | def forward(self, x: torch.Tensor) -> torch.Tensor: 49 | return log_transform(x) 50 | 51 | 52 | class InputNonFinite(torch.nn.Module): 53 | def __init__(self, fill_value: float = 0): 54 | super().__init__() 55 | 56 | self.register_buffer( 57 | "fill_value", torch.as_tensor(fill_value, dtype=torch.float32), persistent=False 58 | ) 59 | 60 | def forward(self, x: torch.Tensor) -> torch.Tensor: 61 | return torch.where(torch.isfinite(x), x, self.fill_value) 62 | 63 | 64 | class Clamp(torch.nn.Module): 65 | def __init__(self, min_value: float, max_value: float): 66 | super().__init__() 67 | # Using buffer to make sure they are on correct device (and not moved every time). 68 | # Will also be part of state_dict. 69 | self.register_buffer( 70 | "min_value", torch.as_tensor(min_value, dtype=torch.float32), persistent=True 71 | ) 72 | self.register_buffer( 73 | "max_value", torch.as_tensor(max_value, dtype=torch.float32), persistent=True 74 | ) 75 | 76 | def forward(self, x: torch.Tensor) -> torch.Tensor: 77 | return torch.clamp(x, min=self.min_value, max=self.max_value) 78 | 79 | 80 | class DoubleNormLog(torch.nn.Module): 81 | """Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features.""" 82 | 83 | def __init__( 84 | self, 85 | input_shapes: Mapping[str, Sequence[int]], 86 | config: DoubleNormLogConfig, 87 | ): 88 | super().__init__() 89 | 90 | _before_concat_layers = [ 91 | InputNonFinite(), 92 | Log1pAbs(), 93 | ] 94 | if config.batch_norm_config: 95 | _before_concat_layers.append( 96 | BatchNorm(input_shapes["continuous"][-1], config.batch_norm_config) 97 | ) 98 | _before_concat_layers.append( 99 | Clamp(min_value=-config.clip_magnitude, max_value=config.clip_magnitude) 100 | ) 101 | self._before_concat_layers = torch.nn.Sequential(*_before_concat_layers) 102 | 103 | self.layer_norm = None 104 | if config.layer_norm_config: 105 | last_dim = input_shapes["continuous"][-1] + input_shapes["binary"][-1] 106 | self.layer_norm = LayerNorm(last_dim, config.layer_norm_config) 107 | 108 | def forward( 109 | self, continuous_features: torch.Tensor, binary_features: torch.Tensor 110 | ) -> torch.Tensor: 111 | x = self._before_concat_layers(continuous_features) 112 | x = torch.cat([x, binary_features], dim=1) 113 | if self.layer_norm: 114 | return self.layer_norm(x) 115 | return x 116 | 117 | 118 | def build_features_preprocessor( 119 | config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]] 120 | ): 121 | """Trivial right now, but we will change in the future.""" 122 | return DoubleNormLog(input_shapes, config.double_norm_log_config) 123 | -------------------------------------------------------------------------------- /projects/home/recap/model/mask_net.py: -------------------------------------------------------------------------------- 1 | """MaskNet: Wang et al. 
(https://arxiv.org/abs/2102.07619).""" 2 | 3 | from tml.projects.home.recap.model import config, mlp 4 | 5 | import torch 6 | 7 | 8 | def _init_weights(module): 9 | if isinstance(module, torch.nn.Linear): 10 | torch.nn.init.xavier_uniform_(module.weight) 11 | torch.nn.init.constant_(module.bias, 0) 12 | 13 | 14 | class MaskBlock(torch.nn.Module): 15 | def __init__( 16 | self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int 17 | ) -> None: 18 | super(MaskBlock, self).__init__() 19 | self.mask_block_config = mask_block_config 20 | output_size = mask_block_config.output_size 21 | 22 | if mask_block_config.input_layer_norm: 23 | self._input_layer_norm = torch.nn.LayerNorm(input_dim) 24 | else: 25 | self._input_layer_norm = None 26 | 27 | if mask_block_config.reduction_factor: 28 | aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor) 29 | elif mask_block_config.aggregation_size is not None: 30 | aggregation_size = mask_block_config.aggregation_size 31 | else: 32 | raise ValueError("Need one of reduction factor or aggregation size.") 33 | 34 | self._mask_layer = torch.nn.Sequential( 35 | torch.nn.Linear(mask_input_dim, aggregation_size), 36 | torch.nn.ReLU(), 37 | torch.nn.Linear(aggregation_size, input_dim), 38 | ) 39 | self._mask_layer.apply(_init_weights) 40 | self._hidden_layer = torch.nn.Linear(input_dim, output_size) 41 | self._hidden_layer.apply(_init_weights) 42 | self._layer_norm = torch.nn.LayerNorm(output_size) 43 | 44 | def forward(self, net: torch.Tensor, mask_input: torch.Tensor): 45 | if self._input_layer_norm: 46 | net = self._input_layer_norm(net) 47 | hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input)) 48 | return self._layer_norm(hidden_layer_output) 49 | 50 | 51 | class MaskNet(torch.nn.Module): 52 | def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int): 53 | super().__init__() 54 | self.mask_net_config = mask_net_config 55 | mask_blocks = [] 56 | 57 | if mask_net_config.use_parallel: 58 | total_output_mask_blocks = 0 59 | for mask_block_config in mask_net_config.mask_blocks: 60 | mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features)) 61 | total_output_mask_blocks += mask_block_config.output_size 62 | self._mask_blocks = torch.nn.ModuleList(mask_blocks) 63 | else: 64 | input_size = in_features 65 | for mask_block_config in mask_net_config.mask_blocks: 66 | mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features)) 67 | input_size = mask_block_config.output_size 68 | 69 | self._mask_blocks = torch.nn.ModuleList(mask_blocks) 70 | total_output_mask_blocks = mask_block_config.output_size 71 | 72 | if mask_net_config.mlp: 73 | self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp) 74 | self.out_features = mask_net_config.mlp.layer_sizes[-1] 75 | else: 76 | self.out_features = total_output_mask_blocks 77 | self.shared_size = total_output_mask_blocks 78 | 79 | def forward(self, inputs: torch.Tensor): 80 | if self.mask_net_config.use_parallel: 81 | mask_outputs = [] 82 | for mask_layer in self._mask_blocks: 83 | mask_outputs.append(mask_layer(mask_input=inputs, net=inputs)) 84 | # Share the outputs of the MaskBlocks. 
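        # Shape sketch (illustrative, not from the original source): with batch size B
        # and two parallel blocks of output_size 64 and 32, mask_outputs is
        # [tensor(B, 64), tensor(B, 32)] and the concatenation below yields a
        # (B, 96) tensor, matching total_output_mask_blocks computed in __init__.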
85 | all_mask_outputs = torch.cat(mask_outputs, dim=1) 86 | output = ( 87 | all_mask_outputs 88 | if self.mask_net_config.mlp is None 89 | else self._dense_layers(all_mask_outputs)["output"] 90 | ) 91 | return {"output": output, "shared_layer": all_mask_outputs} 92 | else: 93 | net = inputs 94 | for mask_layer in self._mask_blocks: 95 | net = mask_layer(net=net, mask_input=inputs) 96 | # Share the output of the stacked MaskBlocks. 97 | output = net if self.mask_net_config.mlp is None else self._dense_layers(net)["output"] 98 | return {"output": output, "shared_layer": net} 99 | -------------------------------------------------------------------------------- /projects/home/recap/model/mlp.py: -------------------------------------------------------------------------------- 1 | """MLP feed forward stack in torch.""" 2 | from typing import Dict 3 | from tml.projects.home.recap.model.config import MlpConfig 4 | 5 | import torch 6 | from absl import logging 7 | 8 | 9 | def _init_weights(module): 10 | if isinstance(module, torch.nn.Linear): 11 | torch.nn.init.xavier_uniform_(module.weight) 12 | torch.nn.init.constant_(module.bias, 0) 13 | 14 | 15 | class Mlp(torch.nn.Module): 16 | def __init__(self, in_features: int, mlp_config: MlpConfig): 17 | super().__init__() 18 | self._mlp_config = mlp_config 19 | input_size = in_features 20 | layer_sizes = mlp_config.layer_sizes 21 | modules = [] 22 | for layer_size in layer_sizes[:-1]: 23 | modules.append(torch.nn.Linear(input_size, layer_size, bias=True)) 24 | 25 | if mlp_config.batch_norm: 26 | modules.append( 27 | torch.nn.BatchNorm1d( 28 | layer_size, affine=mlp_config.batch_norm.affine, momentum=mlp_config.batch_norm.momentum 29 | ) 30 | ) 31 | 32 | modules.append(torch.nn.ReLU()) 33 | 34 | if mlp_config.dropout: 35 | modules.append(torch.nn.Dropout(mlp_config.dropout.rate)) 36 | 37 | input_size = layer_size 38 | modules.append(torch.nn.Linear(input_size, layer_sizes[-1], bias=True)) 39 | if mlp_config.final_layer_activation: 40 | modules.append(torch.nn.ReLU()) 41 | self.layers = torch.nn.ModuleList(modules) 42 | self.layers.apply(_init_weights) 43 | 44 | def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: 45 | net = x 46 | for i, layer in enumerate(self.layers): 47 | net = layer(net) 48 | if i == 1:  # Share the first (widest?) set of activations for other applications. 49 | shared_layer = net 50 | return {"output": net, "shared_layer": shared_layer} 51 | 52 | @property 53 | def shared_size(self): 54 | return self._mlp_config.layer_sizes[-1] 55 | 56 | @property 57 | def out_features(self): 58 | return self._mlp_config.layer_sizes[-1] 59 | -------------------------------------------------------------------------------- /projects/home/recap/model/model_and_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, List 2 | from tml.projects.home.recap.embedding import config as embedding_config_mod 3 | import torch 4 | from absl import logging 5 | 6 | 7 | class ModelAndLoss(torch.nn.Module): 8 | def __init__( 9 | self, 10 | model, 11 | loss_fn: Callable, 12 | stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None, 13 | ) -> None: 14 | """ 15 | Args: 16 | model: torch module to wrap. 17 | loss_fn: Function for calculating loss, should accept logits and labels. 18 | stratifiers: mapping of stratifier name and index of discrete features to emit for metrics stratification.
19 | """ 20 | super().__init__() 21 | self.model = model 22 | self.loss_fn = loss_fn 23 | self.stratifiers = stratifiers 24 | 25 | def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] 26 | """Runs model forward and calculates loss according to given loss_fn. 27 | 28 | NOTE: The input signature here needs to be a Pipelineable object for 29 | prefetching purposes during training using torchrec's pipeline. However 30 | the underlying model signature needs to be exportable to onnx, requiring 31 | generic python types. see https://pytorch.org/docs/stable/onnx.html#types. 32 | 33 | """ 34 | outputs = self.model( 35 | continuous_features=batch.continuous_features, 36 | binary_features=batch.binary_features, 37 | discrete_features=batch.discrete_features, 38 | sparse_features=batch.sparse_features, 39 | user_embedding=batch.user_embedding, 40 | user_eng_embedding=batch.user_eng_embedding, 41 | author_embedding=batch.author_embedding, 42 | labels=batch.labels, 43 | weights=batch.weights, 44 | ) 45 | losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float()) 46 | 47 | if self.stratifiers: 48 | logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}") 49 | outputs["stratifiers"] = {} 50 | for stratifier in self.stratifiers: 51 | outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index] 52 | 53 | # In general, we can have a large number of losses returned by our loss function. 54 | if isinstance(losses, dict): 55 | return losses["loss"], { 56 | **outputs, 57 | **losses, 58 | "labels": batch.labels, 59 | "weights": batch.weights, 60 | } 61 | else: # Assume that this is a float. 62 | return losses, { 63 | **outputs, 64 | "loss": losses, 65 | "labels": batch.labels, 66 | "weights": batch.weights, 67 | } 68 | -------------------------------------------------------------------------------- /projects/home/recap/model/numeric_calibration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class NumericCalibration(torch.nn.Module): 5 | def __init__( 6 | self, 7 | pos_downsampling_rate: float, 8 | neg_downsampling_rate: float, 9 | ): 10 | super().__init__() 11 | 12 | # Using buffer to make sure they are on correct device (and not moved every time). 13 | # Will also be part of state_dict. 14 | self.register_buffer( 15 | "ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True 16 | ) 17 | 18 | def forward(self, probs: torch.Tensor): 19 | return probs * self.ratio / (1.0 - probs + (self.ratio * probs)) 20 | -------------------------------------------------------------------------------- /projects/home/recap/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from tml.projects.home.recap.optimizer.optimizer import build_optimizer 2 | -------------------------------------------------------------------------------- /projects/home/recap/optimizer/config.py: -------------------------------------------------------------------------------- 1 | """Optimization configurations for models.""" 2 | 3 | import typing 4 | 5 | import tml.core.config as base_config 6 | import tml.optimizers.config as optimizers_config_mod 7 | 8 | import pydantic 9 | 10 | 11 | class RecapAdamConfig(base_config.BaseConfig): 12 | beta_1: float = 0.9 # Momentum term. 13 | beta_2: float = 0.999 # Exponential weighted decay factor. 14 | epsilon: float = 1e-7 # Numerical stability in denominator. 
15 | 16 | 17 | class MultiTaskLearningRates(base_config.BaseConfig): 18 | tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field( 19 | description="Learning rates for different towers of the model." 20 | ) 21 | 22 | backbone_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field( 23 | None, description="Learning rate for backbone of the model." 24 | ) 25 | 26 | 27 | class RecapOptimizerConfig(base_config.BaseConfig): 28 | multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field( 29 | None, description="Multiple learning rates for different tasks.", one_of="lr" 30 | ) 31 | 32 | single_task_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field( 33 | None, description="Single task learning rates", one_of="lr" 34 | ) 35 | 36 | adam: RecapAdamConfig = pydantic.Field(one_of="optimizer") 37 | -------------------------------------------------------------------------------- /projects/home/recap/optimizer/optimizer.py: -------------------------------------------------------------------------------- 1 | """Build optimizers and learning rate schedules.""" 2 | import bisect 3 | from collections import defaultdict 4 | import functools 5 | import math 6 | import typing 7 | from typing import Optional 8 | import warnings 9 | 10 | # from large_embeddings.config import EmbeddingOptimizerConfig 11 | from tml.projects.home.recap import model as model_mod 12 | from tml.optimizers import config 13 | from tml.optimizers import compute_lr 14 | from absl import logging  # type: ignore[attr-defined] 15 | 16 | import torch 17 | from torchrec.optim import keyed 18 | 19 | 20 | _DEFAULT_LR = 24601.0  # NaN the model if we're not using the learning rate. 21 | _BACKBONE = "backbone" 22 | _DENSE_EMBEDDINGS = "dense_ebc" 23 | 24 | 25 | class RecapLRShim(torch.optim.lr_scheduler._LRScheduler): 26 | """Shim to get learning rates into an LRScheduler. 27 | 28 | This adheres to the torch.optim scheduler API and can be plugged anywhere that 29 | e.g. exponential decay can be used. 30 | 31 | """ 32 | 33 | def __init__( 34 | self, 35 | optimizer, 36 | lr_dict: typing.Dict[str, config.LearningRate], 37 | emb_learning_rate, 38 | last_epoch=-1, 39 | verbose=False, 40 | ): 41 | self.optimizer = optimizer 42 | self.lr_dict = lr_dict 43 | self.group_names = list(self.lr_dict.keys()) 44 | self.emb_learning_rate = emb_learning_rate 45 | 46 | # We handle sparse LR scheduling separately, so only validate LR groups against dense param groups 47 | num_dense_param_groups = sum( 48 | 1 49 | for _, _optim in optimizer._optims 50 | for _ in _optim.param_groups 51 | if isinstance(_optim, keyed.KeyedOptimizerWrapper) 52 | ) 53 | if num_dense_param_groups != len(lr_dict): 54 | raise ValueError( 55 | f"Optimizer had {num_dense_param_groups} dense param groups, but config had {len(lr_dict)} learning rates." 56 | ) 57 | super().__init__(optimizer, last_epoch, verbose) 58 | 59 | def get_lr(self): 60 | if not self._get_lr_called_within_step: 61 | warnings.warn( 62 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", 63 | UserWarning, 64 | ) 65 | return self._get_closed_form_lr() 66 | 67 | def _get_closed_form_lr(self): 68 | learning_rates = [] 69 | 70 | for lr_config in self.lr_dict.values(): 71 | learning_rates.append(compute_lr(lr_config, self.last_epoch)) 72 | # WARNING: The order of appending is important.
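    # Ordering sketch (illustrative): the LRs emitted here must line up with the
    # param-group order assembled in build_optimizer below, i.e. dense tower groups
    # first (in lr_dict insertion order), then the fused embedding LR, e.g.
    #   lr_dict = {"engagement": ..., "backbone": ...}
    #   -> [lr_engagement, lr_backbone, lr_embedding]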
73 | if self.emb_learning_rate: 74 | learning_rates.append(compute_lr(self.emb_learning_rate, self.last_epoch)) 75 | return learning_rates 76 | 77 | 78 | def build_optimizer( 79 | model: torch.nn.Module, 80 | optimizer_config: config.OptimizerConfig, 81 | emb_optimizer_config: None = None,  # Optional[EmbeddingOptimizerConfig] = None, 82 | ): 83 | """Builds an optimizer and scheduler. 84 | 85 | Args: 86 | model: A torch model, probably with DDP/DMP. 87 | optimizer_config: An OptimizerConfig object that specifies learning rates per tower. 88 | 89 | Returns: 90 | A torch.optim instance, and a scheduler instance. 91 | """ 92 | optimizer_fn = functools.partial( 93 | torch.optim.Adam, 94 | lr=_DEFAULT_LR, 95 | betas=(optimizer_config.adam.beta_1, optimizer_config.adam.beta_2), 96 | eps=optimizer_config.adam.epsilon, 97 | maximize=False, 98 | ) 99 | if optimizer_config.multi_task_learning_rates: 100 | logging.info("***** Parameter groups for optimization *****") 101 | # Importantly, we preserve insertion order in dictionaries here. 102 | parameter_groups: typing.Dict[str, typing.Dict] = defaultdict(dict) 103 | added_parameters: typing.Set[str] = set() 104 | for task in optimizer_config.multi_task_learning_rates.tower_learning_rates: 105 | for name, parameter in model.named_parameters(): 106 | if f".{model_mod.sanitize(task)}." in name: 107 | parameter_groups[task][name] = parameter 108 | logging.info(f"{task}: {name}") 109 | if name in added_parameters: 110 | raise ValueError(f"Parameter {name} matched multiple tasks.") 111 | added_parameters.add(name) 112 | 113 | for name, parameter in model.named_parameters(): 114 | if name not in added_parameters and "embedding_bags" not in name: 115 | parameter_groups[_BACKBONE][name] = parameter 116 | added_parameters.add(name) 117 | logging.info(f"{_BACKBONE}: {name}") 118 | 119 | for name, parameter in model.named_parameters(): 120 | if name not in added_parameters and "embedding_bags" in name: 121 | parameter_groups[_DENSE_EMBEDDINGS][name] = parameter 122 | logging.info(f"{_DENSE_EMBEDDINGS}: {name}") 123 | 124 | all_learning_rates = optimizer_config.multi_task_learning_rates.tower_learning_rates.copy() 125 | if optimizer_config.multi_task_learning_rates.backbone_learning_rate is not None: 126 | all_learning_rates[ 127 | _BACKBONE 128 | ] = optimizer_config.multi_task_learning_rates.backbone_learning_rate 129 | if _DENSE_EMBEDDINGS in parameter_groups and emb_optimizer_config: 130 | all_learning_rates[_DENSE_EMBEDDINGS] = emb_optimizer_config.learning_rate.copy() 131 | else: 132 | parameter_groups = dict(model.named_parameters()) 133 | all_learning_rates = {"single_task": optimizer_config.single_task_learning_rate} 134 | 135 | optimizers = [ 136 | keyed.KeyedOptimizerWrapper(param_group, optimizer_fn) 137 | for param_name, param_group in parameter_groups.items() 138 | if param_name != _DENSE_EMBEDDINGS 139 | ] 140 | # Make the EBC optimizer SGD to match the fused optimizer. 141 | if _DENSE_EMBEDDINGS in parameter_groups: 142 | optimizers.append( 143 | keyed.KeyedOptimizerWrapper( 144 | parameter_groups[_DENSE_EMBEDDINGS], 145 | functools.partial(torch.optim.SGD, lr=_DEFAULT_LR, maximize=False, momentum=0.0), 146 | ) 147 | ) 148 | 149 | if parameter_groups.keys() != all_learning_rates.keys(): 150 | raise ValueError("Learning rates do not match optimizers") 151 | 152 | # If the optimizer is dense, model.fused_optimizer will be empty (but not None). 153 | emb_learning_rate = None 154 | if hasattr(model, "fused_optimizer") and
model.fused_optimizer.optimizers: 155 | logging.info(f"Model fused optimizer: {model.fused_optimizer}") 156 | optimizers.append(model.fused_optimizer) 157 | if emb_optimizer_config: 158 | emb_learning_rate = emb_optimizer_config.learning_rate.copy() 159 | else: 160 | raise ValueError("Fused kernel exists, but LR is not set") 161 | logging.info(f"***** Combining optimizers: {optimizers} *****") 162 | optimizer = keyed.CombinedOptimizer(optimizers) 163 | scheduler = RecapLRShim(optimizer, all_learning_rates, emb_learning_rate) 164 | logging.info(f"***** Combined optimizer after init: {optimizer} *****") 165 | 166 | return optimizer, scheduler 167 | -------------------------------------------------------------------------------- /projects/home/recap/script/create_random_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Runs from inside venv 4 | 5 | rm -rf $HOME/tmp/runs/recap_local_random_data 6 | python -m tml.machines.is_venv || exit 1 7 | export TML_BASE="$(git rev-parse --show-toplevel)" 8 | 9 | mkdir -p $HOME/tmp/recap_local_random_data 10 | python projects/home/recap/data/generate_random_data.py --config_path $(pwd)/projects/home/recap/config/local_prod.yaml 11 | -------------------------------------------------------------------------------- /projects/home/recap/script/run_local.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Runs from inside venv 4 | rm -rf $HOME/tmp/runs/recap_local_debug 5 | mkdir -p $HOME/tmp/runs/recap_local_debug 6 | python -m tml.machines.is_venv || exit 1 7 | export TML_BASE="$(git rev-parse --show-toplevel)" 8 | 9 | torchrun \ 10 | --standalone \ 11 | --nnodes 1 \ 12 | --nproc_per_node 1 \ 13 | projects/home/recap/main.py \ 14 | --config_path $(pwd)/projects/home/recap/config/local_prod.yaml \ 15 | $@ 16 | -------------------------------------------------------------------------------- /projects/twhin/README.md: -------------------------------------------------------------------------------- 1 | TwHIN in torchrec 2 | 3 | This project contains code for pretraining dense vector embedding features for Twitter entities. Within Twitter, these embeddings are used for candidate retrieval and as model features in a variety of recommender system models. 4 | 5 | We obtain entity embeddings based on a variety of graph data within Twitter such as: 6 | "User follows User" 7 | "User favorites Tweet" 8 | "User clicks Advertisement" 9 | 10 | While we cannot release the graph data used to train TwHIN embeddings due to privacy restrictions, heavily subsampled, anonymized, open-sourced graph data can be used: 11 | https://huggingface.co/datasets/Twitter/TwitterFollowGraph 12 | https://huggingface.co/datasets/Twitter/TwitterFaveGraph 13 | 14 | The code expects parquet files with three columns (lhs, rel, rhs) that refer to the vocab index of the left-hand-side node, the relation type, and the right-hand-side node of each edge in the graph, respectively. 15 | 16 | The location of the data must be specified in the configuration yaml files in projects/twhin/configs.
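For reference, a minimal sketch of writing a toy edge file in the expected three-column format with pyarrow (the file name and values here are hypothetical; int64 column types are an assumption):

  import pyarrow as pa
  import pyarrow.parquet as pq

  # Each row is one edge: vocab indices for the left node, the relation type, and the right node.
  table = pa.table({
      "lhs": pa.array([1, 6, 3], type=pa.int64()),
      "rel": pa.array([0, 2, 1], type=pa.int64()),
      "rhs": pa.array([6, 3, 4], type=pa.int64()),
  })
  pq.write_table(table, "edges-00000.parquet")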
17 | 18 | 19 | Workflow 20 | ======== 21 | - Build local development images `./scripts/build_images.sh` 22 | - Run with `./scripts/docker_run.sh` 23 | - Iterate in image with `./scripts/idocker.sh` 24 | - Run tests with `./scripts/docker_test.sh` 25 | -------------------------------------------------------------------------------- /projects/twhin/config.py: -------------------------------------------------------------------------------- 1 | from tml.core.config import base_config 2 | from tml.projects.twhin.data.config import TwhinDataConfig 3 | from tml.projects.twhin.models.config import TwhinModelConfig 4 | from tml.core.config.training import RuntimeConfig, TrainingConfig 5 | 6 | import pydantic 7 | 8 | 9 | class TwhinConfig(base_config.BaseConfig): 10 | runtime: RuntimeConfig = pydantic.Field(RuntimeConfig()) 11 | training: TrainingConfig = pydantic.Field(TrainingConfig()) 12 | model: TwhinModelConfig 13 | train_data: TwhinDataConfig 14 | validation_data: TwhinDataConfig 15 | -------------------------------------------------------------------------------- /projects/twhin/config/local.yaml: -------------------------------------------------------------------------------- 1 | runtime: 2 | enable_amp: false 3 | training: 4 | save_dir: "/tmp/model" 5 | num_train_steps: 100000 6 | checkpoint_every_n: 100000 7 | train_log_every_n: 10 8 | num_eval_steps: 1000 9 | eval_log_every_n: 500 10 | eval_timeout_in_s: 10000 11 | num_epochs: 5 12 | model: 13 | translation_optimizer: 14 | sgd: 15 | lr: 0.05 16 | learning_rate: 17 | constant: 0.05 18 | embeddings: 19 | tables: 20 | - name: user 21 | num_embeddings: 424_241 22 | embedding_dim: 4 23 | data_type: fp32 24 | optimizer: 25 | sgd: 26 | lr: 0.01 27 | learning_rate: 28 | constant: 0.01 29 | - name: tweet 30 | num_embeddings: 72_543 31 | embedding_dim: 4 32 | data_type: fp32 33 | optimizer: 34 | sgd: 35 | lr: 0.005 36 | learning_rate: 37 | constant: 0.005 38 | relations: 39 | - name: fav 40 | lhs: user 41 | rhs: tweet 42 | operator: translation 43 | - name: reply 44 | lhs: user 45 | rhs: tweet 46 | operator: translation 47 | - name: retweet 48 | lhs: user 49 | rhs: tweet 50 | operator: translation 51 | - name: magic_recs 52 | lhs: user 53 | rhs: tweet 54 | operator: translation 55 | train_data: 56 | data_root: "gs://follows_tml_01/tweet_eng/2023-01-23/large/edges/*" 57 | per_replica_batch_size: 500 58 | global_negatives: 0 59 | in_batch_negatives: 10 60 | limit: 9990 61 | validation_data: 62 | data_root: "gs://follows_tml_01/tweet_eng/2023-01-23/large/edges/*" 63 | per_replica_batch_size: 500 64 | global_negatives: 0 65 | in_batch_negatives: 10 66 | limit: 10 67 | offset: 9990 68 | -------------------------------------------------------------------------------- /projects/twhin/data/config.py: -------------------------------------------------------------------------------- 1 | from tml.core.config import base_config 2 | 3 | import pydantic 4 | 5 | 6 | class TwhinDataConfig(base_config.BaseConfig): 7 | data_root: str 8 | per_replica_batch_size: pydantic.PositiveInt 9 | global_negatives: int 10 | in_batch_negatives: int 11 | limit: pydantic.PositiveInt 12 | offset: pydantic.PositiveInt = pydantic.Field( 13 | None, description="The offset to start reading from." 
14 |   )
15 | 
--------------------------------------------------------------------------------
/projects/twhin/data/data.py:
--------------------------------------------------------------------------------
1 | from tml.projects.twhin.data.config import TwhinDataConfig
2 | from tml.projects.twhin.models.config import TwhinModelConfig
3 | from tml.projects.twhin.data.edges import EdgesDataset
4 | 
5 | 
6 | def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig):
7 |   tables = model_config.embeddings.tables
8 |   table_sizes = {table.name: table.num_embeddings for table in tables}
9 |   relations = model_config.relations
10 | 
11 |   pos_batch_size = data_config.per_replica_batch_size
12 | 
13 |   return EdgesDataset(
14 |     file_pattern=data_config.data_root,
15 |     relations=relations,
16 |     table_sizes=table_sizes,
17 |     batch_size=pos_batch_size,
18 |   )
19 | 
--------------------------------------------------------------------------------
/projects/twhin/data/edges.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from dataclasses import dataclass
3 | from typing import Dict, List
4 | 
5 | from tml.common.batch import DataclassBatch
6 | from tml.reader.dataset import Dataset
7 | from tml.projects.twhin.models.config import Relation
8 | 
9 | import numpy as np
10 | import pyarrow as pa
11 | import pyarrow.compute as pc
12 | import torch
13 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
14 | 
15 | 
16 | @dataclass
17 | class EdgeBatch(DataclassBatch):
18 |   nodes: KeyedJaggedTensor
19 |   labels: torch.Tensor
20 |   rels: torch.Tensor
21 |   weights: torch.Tensor
22 | 
23 | 
24 | class EdgesDataset(Dataset):
25 |   rng = np.random.default_rng()
26 | 
27 |   def __init__(
28 |     self,
29 |     file_pattern: str,
30 |     table_sizes: Dict[str, int],
31 |     relations: List[Relation],
32 |     lhs_column_name: str = "lhs",
33 |     rhs_column_name: str = "rhs",
34 |     rel_column_name: str = "rel",
35 |     **dataset_kwargs
36 |   ):
37 |     self.batch_size = dataset_kwargs["batch_size"]
38 | 
39 |     self.table_sizes = table_sizes
40 |     self.num_tables = len(table_sizes)
41 |     self.table_names = list(table_sizes.keys())
42 | 
43 |     self.relations = relations
44 |     self.relations_t = torch.tensor(
45 |       [
46 |         [self.table_names.index(relation.lhs), self.table_names.index(relation.rhs)]
47 |         for relation in relations
48 |       ]
49 |     )
50 | 
51 |     self.lhs_column_name = lhs_column_name
52 |     self.rhs_column_name = rhs_column_name
53 |     self.rel_column_name = rel_column_name
54 |     self.label_column_name = "label"
55 | 
56 |     super().__init__(file_pattern=file_pattern, **dataset_kwargs)
57 | 
58 |   def pa_to_batch(self, batch: pa.RecordBatch):
59 |     lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
60 |     rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
61 |     rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
62 |     label = torch.from_numpy(batch.column(self.label_column_name).to_numpy())
63 | 
64 |     nodes = self._to_kjt(lhs, rhs, rel)
65 |     return EdgeBatch(
66 |       nodes=nodes,
67 |       rels=rel,
68 |       labels=label,
69 |       weights=torch.ones(batch.num_rows),
70 |     )
71 | 
72 |   def _to_kjt(
73 |     self, lhs: torch.Tensor, rhs: torch.Tensor, rel: torch.Tensor
74 |   ) -> KeyedJaggedTensor:
75 | 
76 |     """Process edges that contain lhs index, rhs index, relation index.
77 |     Example:
78 | 
79 |     ```
80 |     tables = ["f0", "f1", "f2"]
81 |     relations = [["f0", "f1"], ["f1", "f2"], ["f1", "f0"], ["f2", "f1"], ["f0", "f2"]]
82 |     self.relations_t = torch.Tensor([[0, 1], [1, 2], [1, 0], [2, 1], [0, 2]])
83 |     lhs = [1, 6, 3, 1, 8]
84 |     rhs = [6, 3, 4, 4, 9]
85 |     rel = [0, 2, 1, 3, 4]
86 | 
87 |     This corresponds to the following "edges":
88 |     edges = [
89 |       {"lhs": 1, "rhs": 6, "relation": ["f0", "f1"]},
90 |       {"lhs": 6, "rhs": 3, "relation": ["f1", "f0"]},
91 |       {"lhs": 3, "rhs": 4, "relation": ["f1", "f2"]},
92 |       {"lhs": 1, "rhs": 4, "relation": ["f2", "f1"]},
93 |       {"lhs": 8, "rhs": 9, "relation": ["f0", "f2"]},
94 |     ]
95 |     ```
96 | 
97 |     Returns a KeyedJaggedTensor used to look up all embeddings.
98 | 
99 |     Note: We treat the lhs and rhs as though they're separate lookups: `len(lengths) == 2 * bsz * len(tables)`.
100 |     This differs from the DLRM pattern where we have `len(lengths) = bsz * len(tables)`.
101 | 
102 |     For the example above:
103 |     ```
104 |     lookups = tensor([
105 |       [0., 1.],
106 |       [1., 6.],
107 |       [1., 6.],
108 |       [0., 3.],
109 |       [1., 3.],
110 |       [2., 4.],
111 |       [2., 1.],
112 |       [1., 4.],
113 |       [0., 8.],
114 |       [2., 9.]
115 |     ])
116 | 
117 |     kjt = KeyedJaggedTensor(
118 |       features=["f0", "f1", "f2"],
119 |       values=[
120 |         1, 3, 8,     # f0
121 |         6, 6, 3, 4,  # f1
122 |         4, 1, 9      # f2
123 |       ],
124 |       lengths=[
125 |         1, 0, 0, 1, 0, 0, 0, 0, 1, 0,  # f0
126 |         0, 1, 1, 0, 1, 0, 0, 1, 0, 0,  # f1
127 |         0, 0, 0, 0, 0, 1, 1, 0, 0, 1,  # f2
128 |     ])
129 |     ```
130 | 
131 |     Note:
132 |     - values = [values for f0] + [values for f1] + [values for f2]
133 |     - lengths are always 0 or 1, and sum(lengths) = len(values) = 2 * bsz
134 |     """
135 |     lookups = torch.concat((lhs[:, None], self.relations_t[rel], rhs[:, None]), dim=1)
136 |     index = torch.LongTensor([1, 0, 2, 3])
137 |     lookups = lookups[:, index].reshape(2 * self.batch_size, 2)
138 | 
139 |     # values is just the row indices into each table, ordered by the table indices
140 |     _, indices = torch.sort(lookups[:, 0], dim=0, stable=True)
141 |     values = lookups[indices][:, 1].int()
142 | 
143 |     # lengths[table_idx * (2 * batch_size) + i] == whether the ith lookup is for the table with index table_idx
144 |     lengths = torch.arange(self.num_tables)[:, None].eq(lookups[:, 0])
145 |     lengths = lengths.reshape(-1).int()
146 | 
147 |     return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)
148 | 
149 |   def to_batches(self):
150 |     ds = super().to_batches()
151 |     batch_size = self._dataset_kwargs["batch_size"]
152 | 
153 |     names = [
154 |       self.lhs_column_name,
155 |       self.rhs_column_name,
156 |       self.rel_column_name,
157 |       self.label_column_name,
158 |     ]
159 |     for batch in ds:
160 |       # Pass along positive edges
161 |       lhs = batch.column(self.lhs_column_name)
162 |       rhs = batch.column(self.rhs_column_name)
163 |       rel = batch.column(self.rel_column_name)
164 |       label = pa.array(np.ones(batch_size, dtype=np.int64))
165 | 
166 |       yield pa.RecordBatch.from_arrays(
167 |         arrays=[lhs, rhs, rel, label],
168 |         names=names,
169 |       )
170 | 
--------------------------------------------------------------------------------
/projects/twhin/data/test_data.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import Mock
3 | 
4 | 
5 | def test_create_dataset():
6 |   pass
7 | 
--------------------------------------------------------------------------------
/projects/twhin/data/test_edges.py:
--------------------------------------------------------------------------------
1 | """Tests edges dataset functionality."""
2 | 
3 | import os
4 | import tempfile
5 | 
6 | from tml.projects.twhin.data.edges import EdgesDataset
7 | from tml.projects.twhin.models.config import Relation
8 | 
9 | from fsspec.implementations.local import LocalFileSystem
10 | import numpy as np
11 | import pyarrow as pa
12 | import pyarrow.parquet as pq
13 | import torch
14 | 
15 | 
16 | TABLE_SIZES = {"user": 16, "author": 32}
17 | RELATIONS = [
18 |   Relation(name="fav", lhs="user", rhs="author"),
19 |   Relation(name="engaged_with_reply", lhs="author", rhs="user"),
20 | ]
21 | 
22 | 
23 | def test_gen():
24 |   lhs = pa.array(np.arange(4))
25 |   rhs = pa.array(np.flip(np.arange(4)))
26 |   rel = pa.array([0, 1, 0, 0])
27 |   names = ["lhs", "rhs", "rel"]
28 | 
29 |   with tempfile.TemporaryDirectory() as tmpdir:
30 |     table = pa.Table.from_arrays([lhs, rhs, rel], names=names)
31 |     writer = pq.ParquetWriter(
32 |       os.path.join(tmpdir, "example.parquet"),
33 |       table.schema,
34 |     )
35 |     writer.write_table(table)
36 |     writer.close()
37 | 
38 |     ds = EdgesDataset(
39 |       file_pattern=os.path.join(tmpdir, "*"),
40 |       table_sizes=TABLE_SIZES,
41 |       relations=RELATIONS,
42 |       batch_size=4,
43 |     )
44 |     ds.FS = LocalFileSystem()
45 | 
46 |     dl = ds.dataloader()
47 |     batch = next(iter(dl))
48 | 
49 |     # labels should be positive
50 |     labels = batch.labels
51 |     assert (labels[:4] == 1).sum() == 4
52 | 
53 |     # make sure positive examples are what we expect
54 |     kjt_values = batch.nodes.values()
55 |     users, authors = torch.split(kjt_values, 4, dim=0)
56 |     assert torch.equal(users[:4], torch.tensor([0, 2, 2, 3]))
57 |     assert torch.equal(authors[:4], torch.tensor([3, 1, 1, 0]))
58 | 
--------------------------------------------------------------------------------
/projects/twhin/machines.yaml:
--------------------------------------------------------------------------------
1 | chief: &gpu
2 |   mem: 1.4Ti
3 |   cpu: 24
4 |   num_accelerators: 16
5 |   accelerator_type: a100
6 | dataset_dispatcher:
7 |   mem: 2Gi
8 |   cpu: 2
9 | num_dataset_workers: 4
10 | dataset_worker:
11 |   mem: 14Gi
12 |   cpu: 2
13 | 
--------------------------------------------------------------------------------
/projects/twhin/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchmetrics as tm
3 | 
4 | import tml.core.metrics as core_metrics
5 | 
6 | 
7 | def create_metrics(
8 |   device: torch.device,
9 | ):
10 |   metrics = {
11 |     "AUC": core_metrics.Auc(128),
12 |   }
13 |   metrics = tm.MetricCollection(metrics).to(device)
14 |   return metrics
15 | 
--------------------------------------------------------------------------------
/projects/twhin/models/config.py:
--------------------------------------------------------------------------------
1 | import typing
2 | import enum
3 | 
4 | from tml.common.modules.embedding.config import LargeEmbeddingsConfig
5 | from tml.core.config import base_config
6 | from tml.optimizers.config import OptimizerConfig
7 | 
8 | import pydantic
9 | from pydantic import validator
10 | 
11 | 
12 | class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
13 |   @validator("tables")
14 |   def embedding_dims_match(cls, tables):
15 |     embedding_dim = tables[0].embedding_dim
16 |     data_type = tables[0].data_type
17 |     for table in tables:
18 |       assert table.embedding_dim == embedding_dim, "Embedding dimensions for all nodes must match."
19 |       assert table.data_type == data_type, "Data types for all nodes must match."
20 |     return tables
21 | 
22 | 
23 | class Operator(str, enum.Enum):
24 |   TRANSLATION = "translation"
25 | 
26 | 
27 | class Relation(pydantic.BaseModel):
28 |   """Graph relationship properties and operator."""
29 | 
30 |   name: str = pydantic.Field(..., description="Relationship name.")
31 |   lhs: str = pydantic.Field(
32 |     ...,
33 |     description="Name of the entity on the left-hand-side of this relation. Must match a table name.",
34 |   )
35 |   rhs: str = pydantic.Field(
36 |     ...,
37 |     description="Name of the entity on the right-hand-side of this relation. Must match a table name.",
38 |   )
39 |   operator: Operator = pydantic.Field(
40 |     Operator.TRANSLATION, description="Transformation to apply to lhs embedding before dot product."
41 |   )
42 | 
43 | 
44 | class TwhinModelConfig(base_config.BaseConfig):
45 |   embeddings: TwhinEmbeddingsConfig
46 |   relations: typing.List[Relation]
47 |   translation_optimizer: OptimizerConfig
48 | 
49 |   @validator("relations", each_item=True)
50 |   def valid_node_types(cls, relation, values, **kwargs):
51 |     table_names = [table.name for table in values["embeddings"].tables]
52 |     assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
53 |     assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"
54 |     return relation
55 | 
--------------------------------------------------------------------------------
/projects/twhin/models/models.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 | import math
3 | 
4 | from tml.projects.twhin.data.edges import EdgeBatch
5 | from tml.projects.twhin.models.config import TwhinModelConfig
6 | from tml.projects.twhin.data.config import TwhinDataConfig
7 | from tml.common.modules.embedding.embedding import LargeEmbeddings
8 | from tml.optimizers.optimizer import get_optimizer_class
9 | from tml.optimizers.config import get_optimizer_algorithm_config
10 | 
11 | import torch
12 | from torch import nn
13 | from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
14 | 
15 | 
16 | class TwhinModel(nn.Module):
17 |   def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
18 |     super().__init__()
19 |     self.batch_size = data_config.per_replica_batch_size
20 |     self.table_names = [table.name for table in model_config.embeddings.tables]
21 |     self.large_embeddings = LargeEmbeddings(model_config.embeddings)
22 |     self.embedding_dim = model_config.embeddings.tables[0].embedding_dim
23 |     self.num_tables = len(model_config.embeddings.tables)
24 |     self.in_batch_negatives = data_config.in_batch_negatives
25 |     self.global_negatives = data_config.global_negatives
26 |     self.num_relations = len(model_config.relations)
27 | 
28 |     # one translation embedding per relation
29 |     self.all_trans_embs = torch.nn.parameter.Parameter(
30 |       torch.nn.init.uniform_(torch.empty(self.num_relations, self.embedding_dim))
31 |     )
32 | 
33 |   def forward(self, batch: EdgeBatch):
34 | 
35 |     # B x D
36 |     trans_embs = self.all_trans_embs.data[batch.rels]
37 | 
38 |     # KeyedTensor
39 |     outs = self.large_embeddings(batch.nodes)
40 | 
41 |     # 2B x TD
42 |     x = outs.values()
43 | 
44 |     # 2B x T x D
45 |     x = x.reshape(2 * self.batch_size, -1, self.embedding_dim)
46 | 
47 |     # 2B x D
48 |     x = torch.sum(x, 1)
49 | 
50 |     # B x 2 x D
51 |     x = x.reshape(self.batch_size, 2, self.embedding_dim)
52 | 
53 |     # translate rhs embeddings by the per-relation translation vector
54 |     translated = x[:, 1, :] + trans_embs
55 | 
56 |     negs = []
57 |     if self.in_batch_negatives:
58 |       # construct dot products for negatives via matmul
59 |       for relation in range(self.num_relations):
60 |         rel_mask = batch.rels == relation
61 |         rel_count = rel_mask.sum()
62 | 
63 |         if not rel_count:
64 |           continue
65 | 
66 |         # R x D
67 |         lhs_matrix = x[rel_mask, 0, :]
68 |         rhs_matrix = x[rel_mask, 1, :]
69 | 
70 |         lhs_perm = torch.randperm(lhs_matrix.shape[0])
71 |         # repeat until we have enough negatives
72 |         lhs_perm = lhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
73 |         lhs_indices = lhs_perm[: self.in_batch_negatives]
74 |         sampled_lhs = lhs_matrix[lhs_indices]
75 | 
76 |         rhs_perm = torch.randperm(rhs_matrix.shape[0])
77 |         # repeat until we have enough negatives
78 |         rhs_perm = rhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
79 |         rhs_indices = rhs_perm[: self.in_batch_negatives]
80 |         sampled_rhs = rhs_matrix[rhs_indices]
81 | 
82 |         # R x S negative logits per side, flattened
83 |         negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))
84 |         negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))
85 | 
86 |         negs.append(negs_lhs)
87 |         negs.append(negs_rhs)
88 | 
89 |     # dot product for positives
90 |     x = (x[:, 0, :] * translated).sum(-1)
91 | 
92 |     # concat positives and negatives
93 |     x = torch.cat([x, *negs])
94 |     return {
95 |       "logits": x,
96 |       "probabilities": torch.sigmoid(x),
97 |     }
98 | 
99 | 
100 | def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
101 |   for table in model_config.embeddings.tables:
102 |     optimizer_class = get_optimizer_class(table.optimizer)
103 |     optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
104 |     params = [
105 |       param
106 |       for name, param in model.large_embeddings.ebc.named_parameters()
107 |       if (name.startswith(f"embedding_bags.{table.name}"))
108 |     ]
109 |     apply_optimizer_in_backward(
110 |       optimizer_class=optimizer_class,
111 |       params=params,
112 |       optimizer_kwargs=optimizer_kwargs,
113 |     )
114 | 
115 |   return model
116 | 
117 | 
118 | class TwhinModelAndLoss(torch.nn.Module):
119 |   def __init__(
120 |     self,
121 |     model,
122 |     loss_fn: Callable,
123 |     data_config: TwhinDataConfig,
124 |     device: torch.device,
125 |   ) -> None:
126 |     """
127 |     Args:
128 |       model: torch module to wrap.
129 |       loss_fn: Function for calculating loss, should accept logits and labels.
130 |     """
131 |     super().__init__()
132 |     self.model = model
133 |     self.loss_fn = loss_fn
134 |     self.batch_size = data_config.per_replica_batch_size
135 |     self.in_batch_negatives = data_config.in_batch_negatives
136 |     self.device = device
137 | 
138 |   def forward(self, batch: EdgeBatch):
139 |     """Runs model forward and calculates loss according to given loss_fn.
140 | 
141 |     NOTE: The input signature here needs to be a Pipelineable object for
142 |     prefetching purposes during training using torchrec's pipeline. However
143 |     the underlying model signature needs to be exportable to onnx, requiring
144 |     generic python types. See https://pytorch.org/docs/stable/onnx.html#types.
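    For a concrete sense of scale: with per_replica_batch_size=10 and
    in_batch_negatives=10 (the values used by twhin_data_config in
    test_models.py), each step yields 10 positive logits plus
    2 * 10 * 10 = 200 negative logits, so neg_weight = 10 / 200 = 0.05 below.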
145 | 
146 |     """
147 |     outputs = self.model(batch)
148 |     logits = outputs["logits"]
149 | 
150 |     num_negatives = 2 * self.batch_size * self.in_batch_negatives
151 |     num_positives = self.batch_size
152 | 
153 |     neg_weight = float(num_positives) / num_negatives
154 | 
155 |     # In-batch negatives are randomly sampled pairs, so they take label 0.
156 |     labels = torch.cat([batch.labels.float(), torch.zeros(num_negatives).to(self.device)])
157 | 
158 |     weights = torch.cat(
159 |       [batch.weights.float(), (torch.ones(num_negatives) * neg_weight).to(self.device)]
160 |     )
161 | 
162 |     losses = self.loss_fn(logits, labels, weights)
163 | 
164 |     outputs.update(
165 |       {
166 |         "loss": losses,
167 |         "labels": labels,
168 |         "weights": weights,
169 |       }
170 |     )
171 | 
172 |     # Allow multiple losses.
173 |     return losses, outputs
174 | 
--------------------------------------------------------------------------------
/projects/twhin/models/test_models.py:
--------------------------------------------------------------------------------
1 | from tml.projects.twhin.models.config import TwhinEmbeddingsConfig, TwhinModelConfig
2 | from tml.projects.twhin.data.config import TwhinDataConfig
3 | from tml.common.modules.embedding.config import DataType, EmbeddingBagConfig
4 | from tml.optimizers.config import OptimizerConfig, SgdConfig
5 | from tml.model import maybe_shard_model
6 | from tml.projects.twhin.models.models import apply_optimizers, TwhinModel
7 | from tml.projects.twhin.models.config import Operator, Relation
8 | from tml.common.testing_utils import mock_pg
9 | 
10 | import torch
11 | import torch.nn.functional as F
12 | from pydantic import ValidationError
13 | import pytest
14 | 
15 | 
16 | NUM_EMBS = 10_000
17 | EMB_DIM = 128
18 | 
19 | 
20 | def twhin_model_config() -> TwhinModelConfig:
21 |   sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
22 |   sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
23 | 
24 |   table0 = EmbeddingBagConfig(
25 |     name="table0",
26 |     num_embeddings=NUM_EMBS,
27 |     embedding_dim=EMB_DIM,
28 |     optimizer=sgd_config_0,
29 |     data_type=DataType.FP32,
30 |   )
31 |   table1 = EmbeddingBagConfig(
32 |     name="table1",
33 |     num_embeddings=NUM_EMBS,
34 |     embedding_dim=EMB_DIM,
35 |     optimizer=sgd_config_1,
36 |     data_type=DataType.FP32,
37 |   )
38 |   embeddings_config = TwhinEmbeddingsConfig(
39 |     tables=[table0, table1],
40 |   )
41 | 
42 |   model_config = TwhinModelConfig(
43 |     embeddings=embeddings_config,
44 |     translation_optimizer=sgd_config_0,
45 |     relations=[
46 |       Relation(name="rel0", lhs="table0", rhs="table1", operator=Operator.TRANSLATION),
47 |       Relation(name="rel1", lhs="table1", rhs="table0", operator=Operator.TRANSLATION),
48 |     ],
49 |   )
50 | 
51 |   return model_config
52 | 
53 | 
54 | def twhin_data_config() -> TwhinDataConfig:
55 |   data_config = TwhinDataConfig(
56 |     data_root="/",
57 |     per_replica_batch_size=10,
58 |     global_negatives=10,
59 |     in_batch_negatives=10,
60 |     limit=1,
61 |     offset=1,
62 |   )
63 | 
64 |   return data_config
65 | 
66 | 
67 | def test_twhin_model():
68 |   model_config = twhin_model_config()
69 |   loss_fn = F.binary_cross_entropy_with_logits
70 | 
71 |   with mock_pg():
72 |     data_config = twhin_data_config()
73 |     model = TwhinModel(model_config=model_config, data_config=data_config)
74 | 
75 |     apply_optimizers(model, model_config)
76 | 
77 |     for tensor in model.state_dict().values():
78 |       if tensor.size() == (NUM_EMBS, EMB_DIM):
79 |         assert str(tensor.device) == "meta"
80 |       else:
81 |         assert str(tensor.device) == "cpu"
82 | 
83 |     model = maybe_shard_model(model, device=torch.device("cpu"))
84 | 
85 | 
86 | def test_unequal_dims():
87 |   sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
88 |   sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))
89 |   table0 = EmbeddingBagConfig(
90 |     name="table0",
91 |     num_embeddings=10_000,
92 |     embedding_dim=128,
93 |     optimizer=sgd_config_1,
94 |     data_type=DataType.FP32,
95 |   )
96 |   table1 = EmbeddingBagConfig(
97 |     name="table1",
98 |     num_embeddings=10_000,
99 |     embedding_dim=64,
100 |     optimizer=sgd_config_2,
101 |     data_type=DataType.FP32,
102 |   )
103 | 
104 |   with pytest.raises(ValidationError):
105 |     _ = TwhinEmbeddingsConfig(
106 |       tables=[table0, table1],
107 |     )
108 | 
--------------------------------------------------------------------------------
/projects/twhin/optimizer.py:
--------------------------------------------------------------------------------
1 | import functools
2 | 
3 | from tml.projects.twhin.models.config import TwhinModelConfig
4 | from tml.projects.twhin.models.models import TwhinModel
5 | from tml.optimizers.optimizer import get_optimizer_class, LRShim
6 | from tml.optimizers.config import get_optimizer_algorithm_config, LearningRate
7 | from tml.ml_logging.torch_logging import logging
8 | 
9 | from torchrec.optim.optimizers import in_backward_optimizer_filter
10 | from torchrec.optim import keyed
11 | 
12 | 
13 | FUSED_OPT_KEY = "fused_opt"
14 | TRANSLATION_OPT_KEY = "operator_opt"
15 | 
16 | 
17 | def _lr_from_config(optimizer_config):
18 |   if optimizer_config.learning_rate is not None:
19 |     return optimizer_config.learning_rate
20 |   else:
21 |     # treat None as constant lr
22 |     lr_value = get_optimizer_algorithm_config(optimizer_config).lr
23 |     return LearningRate(constant=lr_value)
24 | 
25 | 
26 | def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
27 |   """Builds an optimizer for a Twhin model, combining the fused embeddings optimizer with an optimizer for the per-relation translations.
28 | 
29 |   Args:
30 |     model: TwhinModel to build optimizer for.
31 |     config: TwhinModelConfig for the model.
32 | 
33 |   Returns:
34 |     Optimizer for model.
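    The result is a keyed.CombinedOptimizer that nests the model's fused
    embedding optimizer under FUSED_OPT_KEY and the translation optimizer
    under TRANSLATION_OPT_KEY. The scheduler slot of the returned tuple is
    currently None because the LRShim hookup below is commented out.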
35 | """ 36 | translation_optimizer_fn = functools.partial( 37 | get_optimizer_class(config.translation_optimizer), 38 | **get_optimizer_algorithm_config(config.translation_optimizer).dict(), 39 | ) 40 | 41 | translation_optimizer = keyed.KeyedOptimizerWrapper( 42 | dict(in_backward_optimizer_filter(model.named_parameters())), 43 | optim_factory=translation_optimizer_fn, 44 | ) 45 | 46 | lr_dict = {} 47 | for table in config.embeddings.tables: 48 | lr_dict[table.name] = _lr_from_config(table.optimizer) 49 | lr_dict[TRANSLATION_OPT_KEY] = _lr_from_config(config.translation_optimizer) 50 | 51 | logging.info(f"***** LR dict: {lr_dict} *****") 52 | 53 | logging.info( 54 | f"***** Combining fused optimizer {model.fused_optimizer} with operator optimizer: {translation_optimizer} *****" 55 | ) 56 | optimizer = keyed.CombinedOptimizer( 57 | [ 58 | (FUSED_OPT_KEY, model.fused_optimizer), 59 | (TRANSLATION_OPT_KEY, translation_optimizer), 60 | ] 61 | ) 62 | 63 | # scheduler = LRShim(optimizer, lr_dict) 64 | scheduler = None 65 | 66 | logging.info(f"***** Combined optimizer after init: {optimizer} *****") 67 | 68 | return optimizer, scheduler 69 | -------------------------------------------------------------------------------- /projects/twhin/run.py: -------------------------------------------------------------------------------- 1 | from absl import app, flags 2 | import json 3 | from typing import Optional 4 | import os 5 | import sys 6 | 7 | import torch 8 | 9 | # isort: on 10 | from tml.common.device import setup_and_get_device 11 | from tml.common.utils import setup_configuration 12 | import tml.core.custom_training_loop as ctl 13 | import tml.machines.environment as env 14 | from tml.projects.twhin.models.models import apply_optimizers, TwhinModel, TwhinModelAndLoss 15 | from tml.model import maybe_shard_model 16 | from tml.projects.twhin.metrics import create_metrics 17 | from tml.projects.twhin.config import TwhinConfig 18 | from tml.projects.twhin.data.data import create_dataset 19 | from tml.projects.twhin.optimizer import build_optimizer 20 | 21 | from tml.ml_logging.torch_logging import logging 22 | 23 | import torch.distributed as dist 24 | from torch.nn import functional as F 25 | from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward 26 | from torchrec.distributed.model_parallel import get_module 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_bool("overwrite_save_dir", False, "Whether to clear preexisting save directories.") 31 | flags.DEFINE_string("save_dir", None, "If provided, overwrites the save directory.") 32 | flags.DEFINE_string("config_yaml_path", None, "Path to hyperparameters for model.") 33 | flags.DEFINE_string("task", None, "Task to run if this is local. 
Overrides TF_CONFIG etc.") 34 | 35 | 36 | def run( 37 | all_config: TwhinConfig, 38 | save_dir: Optional[str] = None, 39 | ): 40 | train_dataset = create_dataset(all_config.train_data, all_config.model) 41 | 42 | if env.is_reader(): 43 | train_dataset.serve() 44 | if env.is_chief(): 45 | device = setup_and_get_device(tf_ok=False) 46 | logging.info(f"device: {device}") 47 | logging.info(f"WORLD_SIZE: {dist.get_world_size()}") 48 | 49 | # validation_dataset = create_dataset(all_config.validation_data, all_config.model) 50 | 51 | global_batch_size = all_config.train_data.per_replica_batch_size * dist.get_world_size() 52 | 53 | metrics = create_metrics(device) 54 | 55 | model = TwhinModel(all_config.model, all_config.train_data) 56 | apply_optimizers(model, all_config.model) 57 | model = maybe_shard_model(model, device=device) 58 | optimizer, scheduler = build_optimizer(model=model, config=all_config.model) 59 | 60 | loss_fn = F.binary_cross_entropy_with_logits 61 | model_and_loss = TwhinModelAndLoss( 62 | model, loss_fn, data_config=all_config.train_data, device=device 63 | ) 64 | 65 | ctl.train( 66 | model=model_and_loss, 67 | optimizer=optimizer, 68 | device=device, 69 | save_dir=save_dir, 70 | logging_interval=all_config.training.train_log_every_n, 71 | train_steps=all_config.training.num_train_steps, 72 | checkpoint_frequency=all_config.training.checkpoint_every_n, 73 | dataset=train_dataset.dataloader(remote=False), 74 | worker_batch_size=global_batch_size, 75 | num_workers=0, 76 | scheduler=scheduler, 77 | initial_checkpoint_dir=all_config.training.initial_checkpoint_dir, 78 | gradient_accumulation=all_config.training.gradient_accumulation, 79 | ) 80 | 81 | 82 | def main(argv): 83 | logging.info("Starting") 84 | 85 | logging.info(f"parsing config from {FLAGS.config_yaml_path}...") 86 | all_config = setup_configuration( # type: ignore[var-annotated] 87 | TwhinConfig, 88 | yaml_path=FLAGS.config_yaml_path, 89 | ) 90 | 91 | run( 92 | all_config, 93 | save_dir=FLAGS.save_dir, 94 | ) 95 | 96 | 97 | if __name__ == "__main__": 98 | app.run(main) 99 | -------------------------------------------------------------------------------- /projects/twhin/scripts/docker_run.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | docker run -it --rm \ 4 | -v $HOME/workspace/tml:/usr/src/app/tml \ 5 | -v $HOME/.config:/root/.config \ 6 | -w /usr/src/app \ 7 | -e PYTHONPATH="/usr/src/app/" \ 8 | --network host \ 9 | -e SPEC_TYPE=chief \ 10 | local/torch \ 11 | bash tml/projects/twhin/scripts/run_in_docker.sh 12 | -------------------------------------------------------------------------------- /projects/twhin/scripts/run_in_docker.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/sh 2 | 3 | torchrun \ 4 | --standalone \ 5 | --nnodes 1 \ 6 | --nproc_per_node 2 \ 7 | /usr/src/app/tml/projects/twhin/run.py \ 8 | --config_yaml_path="/usr/src/app/tml/projects/twhin/config/local.yaml" \ 9 | --save_dir="/some/save/dir" 10 | -------------------------------------------------------------------------------- /projects/twhin/test_optimizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import unittest 3 | 4 | from tml.projects.twhin.models.models import TwhinModel, apply_optimizers 5 | from tml.projects.twhin.models.test_models import twhin_model_config, twhin_data_config 6 | from tml.projects.twhin.optimizer import build_optimizer 7 | from tml.model import maybe_shard_model 8 | from tml.common.testing_utils import mock_pg 9 | 10 | 11 | import torch 12 | from torch.nn import functional as F 13 | 14 | 15 | def test_twhin_optimizer(): 16 | model_config = twhin_model_config() 17 | data_config = twhin_data_config() 18 | 19 | loss_fn = F.binary_cross_entropy_with_logits 20 | with mock_pg(): 21 | model = TwhinModel(model_config, data_config) 22 | apply_optimizers(model, model_config) 23 | model = maybe_shard_model(model, device=torch.device("cpu")) 24 | 25 | optimizer, _ = build_optimizer(model, model_config) 26 | 27 | # make sure there is one combined fused optimizer and one translation optimizer 28 | assert len(optimizer.optimizers) == 2 29 | fused_opt_tup, _ = optimizer.optimizers 30 | _, fused_opt = fused_opt_tup 31 | 32 | # make sure there are two tables for which the fused opt has parameters 33 | assert len(fused_opt.param_groups) == 2 34 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | include = '\.pyi?$' 4 | exclude = ''' 5 | /( 6 | \.git 7 | | \.hg 8 | | \.pem 9 | | \.mypy_cache 10 | | \.tox 11 | | \.venv 12 | | _build 13 | | buck-out 14 | | build 15 | | dist 16 | )/ 17 | ''' 18 | -------------------------------------------------------------------------------- /reader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twitter/the-algorithm-ml/b85210863f7a94efded0ef5c5ccf4ff42767876c/reader/__init__.py -------------------------------------------------------------------------------- /reader/dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset to be overwritten that can work with or without distributed reading. 2 | 3 | - Override `pa_to_batch` for dataset specific imputation, negative sampling, or coercion to Batch. 4 | - Readers can be colocated or off trainer machines. 
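- A minimal subclass only needs `pa_to_batch`. Sketch (mirroring the MockDataset
  in reader/test_dataset.py, where create_default_pa_to_batch builds an imputing
  converter from the schema):

    class MyDataset(Dataset):
      def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._converter = reader_utils.create_default_pa_to_batch(self._schema)

      def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
        return self._converter(batch)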
5 | 
6 | """
7 | import abc
8 | import functools
9 | import random
10 | from typing import Optional
11 | 
12 | from fsspec.implementations.local import LocalFileSystem
13 | import pyarrow.dataset as pads
14 | import pyarrow as pa
15 | import pyarrow.parquet
16 | import pyarrow.flight
17 | from pyarrow.ipc import IpcWriteOptions
18 | import torch
19 | 
20 | from tml.common.batch import DataclassBatch
21 | from tml.machines import environment as env
22 | import tml.reader.utils as reader_utils
23 | from tml.common.filesystem import infer_fs
24 | from tml.ml_logging.torch_logging import logging
25 | 
26 | 
27 | class _Reader(pa.flight.FlightServerBase):
28 |   """Distributed reader flight server wrapping a dataset."""
29 | 
30 |   def __init__(self, location: str, ds: "Dataset"):
31 |     super().__init__(location=location)
32 |     self._location = location
33 |     self._ds = ds
34 | 
35 |   def do_get(self, _, __):
36 |     # NB: An updated schema (to account for column selection) has to be given to the stream.
37 |     schema = next(iter(self._ds.to_batches())).schema
38 |     batches = self._ds.to_batches()
39 |     return pa.flight.RecordBatchStream(
40 |       data_source=pa.RecordBatchReader.from_batches(
41 |         schema=schema,
42 |         batches=batches,
43 |       ),
44 |       options=IpcWriteOptions(use_threads=True),
45 |     )
46 | 
47 | 
48 | class Dataset(torch.utils.data.IterableDataset):
49 |   LOCATION = "grpc://0.0.0.0:2222"
50 | 
51 |   def __init__(self, file_pattern: str, **dataset_kwargs) -> None:
52 |     """Specify batch size and columns to select for.
53 | 
54 |     Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.
55 |     """
56 |     self._file_pattern = file_pattern
57 |     self._fs = infer_fs(self._file_pattern)
58 |     self._dataset_kwargs = dataset_kwargs
59 |     logging.info(f"Using dataset_kwargs: {self._dataset_kwargs}")
60 |     self._files = self._fs.glob(self._file_pattern)
61 |     assert len(self._files) > 0, f"No files found at {self._file_pattern}"
62 |     logging.info(f"Found {len(self._files)} files: {', '.join(self._files[:4])}, ...")
63 |     self._schema = pa.parquet.read_schema(self._files[0], filesystem=self._fs)
64 |     self._validate_columns()
65 | 
66 |   def _validate_columns(self):
67 |     columns = set(self._dataset_kwargs.get("columns", []))
68 |     wrong_columns = columns - set(self._schema.names)
69 |     if wrong_columns:
70 |       raise Exception(f"Specified columns {list(wrong_columns)} not in schema.")
71 | 
72 |   def serve(self):
73 |     self.reader = _Reader(location=self.LOCATION, ds=self)
74 |     self.reader.serve()
75 | 
76 |   def _create_dataset(self):
77 |     return pads.dataset(
78 |       source=random.choice(self._files),  # read one file, chosen at random, per repeat
79 |       format="parquet",
80 |       filesystem=self._fs,
81 |       exclude_invalid_files=False,
82 |     )
83 | 
84 |   def to_batches(self):
85 |     """This allows the init to control reading settings.
86 | 
87 |     Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.
88 | 
89 |     Performs `drop_remainder` behavior to fix the batch size.
90 |     This does not shift our data distribution because of volume and file-level shuffling on every repeat.
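    For example, with batch_size=4096 a trailing 1,000-row batch at the end of
    a file is logged and dropped rather than emitted short; iteration then
    resumes on a freshly sampled file.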
91 | """ 92 | batch_size = self._dataset_kwargs["batch_size"] 93 | while True: 94 | ds = self._create_dataset() 95 | for batch in ds.to_batches(**self._dataset_kwargs): 96 | if batch.num_rows < batch_size: 97 | logging.info(f"Dropping remainder ({batch.num_rows}/{batch_size})") 98 | break 99 | yield batch 100 | 101 | @abc.abstractmethod 102 | def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch: 103 | raise NotImplementedError 104 | 105 | def dataloader(self, remote: bool = False): 106 | if not remote: 107 | return map(self.pa_to_batch, self.to_batches()) 108 | readers = get_readers(2) 109 | return map(self.pa_to_batch, reader_utils.roundrobin(*readers)) 110 | 111 | 112 | GRPC_OPTIONS = [ 113 | ("GRPC_ARG_KEEPALIVE_TIME_MS", 60000), 114 | ("GRPC_ARG_MIN_RECONNECT_BACKOFF_MS", 2000), 115 | ("GRPC_ARG_MAX_METADATA_SIZE", 1024 * 1024 * 1024), 116 | ] 117 | 118 | 119 | def get_readers(num_readers_per_worker: int): 120 | addresses = env.get_flight_server_addresses() 121 | 122 | readers = [] 123 | for worker in addresses: 124 | logging.info(f"Attempting connection to reader {worker}.") 125 | client = pa.flight.connect(worker, generic_options=GRPC_OPTIONS) 126 | client.wait_for_available(60) 127 | reader = client.do_get(None).to_reader() 128 | logging.info(f"Connected reader to {worker}.") 129 | readers.append(reader) 130 | return readers 131 | -------------------------------------------------------------------------------- /reader/dds.py: -------------------------------------------------------------------------------- 1 | """Dataset service orchestrated by a TFJob 2 | """ 3 | from typing import Optional 4 | import uuid 5 | 6 | from tml.ml_logging.torch_logging import logging 7 | import tml.machines.environment as env 8 | 9 | import packaging.version 10 | import tensorflow as tf 11 | 12 | try: 13 | import tensorflow_io as tfio 14 | except: 15 | pass 16 | from tensorflow.python.data.experimental.ops.data_service_ops import ( 17 | _from_dataset_id, 18 | _register_dataset, 19 | ) 20 | import torch.distributed as dist 21 | 22 | 23 | def maybe_start_dataset_service(): 24 | if not env.has_readers(): 25 | return 26 | 27 | if packaging.version.parse(tf.__version__) < packaging.version.parse("2.5"): 28 | raise Exception(f"maybe_distribute_dataset requires TF >= 2.5; got {tf.__version__}") 29 | 30 | if env.is_dispatcher(): 31 | logging.info(f"env.get_reader_port() = {env.get_reader_port()}") 32 | logging.info(f"env.get_dds_journaling_dir() = {env.get_dds_journaling_dir()}") 33 | work_dir = env.get_dds_journaling_dir() 34 | server = tf.data.experimental.service.DispatchServer( 35 | tf.data.experimental.service.DispatcherConfig( 36 | port=env.get_reader_port(), 37 | protocol="grpc", 38 | work_dir=work_dir, 39 | fault_tolerant_mode=bool(work_dir), 40 | ) 41 | ) 42 | server.join() 43 | 44 | elif env.is_reader(): 45 | logging.info(f"env.get_reader_port() = {env.get_reader_port()}") 46 | logging.info(f"env.get_dds_dispatcher_address() = {env.get_dds_dispatcher_address()}") 47 | logging.info(f"env.get_dds_worker_address() = {env.get_dds_worker_address()}") 48 | server = tf.data.experimental.service.WorkerServer( 49 | tf.data.experimental.service.WorkerConfig( 50 | port=env.get_reader_port(), 51 | dispatcher_address=env.get_dds_dispatcher_address(), 52 | worker_address=env.get_dds_worker_address(), 53 | protocol="grpc", 54 | ) 55 | ) 56 | server.join() 57 | 58 | 59 | def register_dataset( 60 | dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO" 61 | ): 62 | if 
dist.get_rank() == 0: 63 | dataset_id = _register_dataset( 64 | service=dataset_service, 65 | dataset=dataset, 66 | compression=compression, 67 | ) 68 | job_name = uuid.uuid4().hex[:8] 69 | id_and_job = [dataset_id.numpy(), job_name] 70 | logging.info(f"rank{dist.get_rank()}: Created dds job with {dataset_id.numpy()}, {job_name}") 71 | else: 72 | id_and_job = [None, None] 73 | 74 | dist.broadcast_object_list(id_and_job, src=0) 75 | return tuple(id_and_job) 76 | 77 | 78 | def distribute_from_dataset_id( 79 | dataset_service: str, 80 | dataset_id: int, 81 | job_name: Optional[str], 82 | compression: Optional[str] = "AUTO", 83 | prefetch: Optional[int] = tf.data.experimental.AUTOTUNE, 84 | ) -> tf.data.Dataset: 85 | logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}") 86 | dataset = _from_dataset_id( 87 | processing_mode="parallel_epochs", 88 | service=dataset_service, 89 | dataset_id=dataset_id, 90 | job_name=job_name, 91 | element_spec=None, 92 | compression=compression, 93 | ) 94 | if prefetch is not None: 95 | dataset = dataset.prefetch(prefetch) 96 | return dataset 97 | 98 | 99 | def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset: 100 | """Torch-compatible and distributed-training-aware dataset service distributor. 101 | 102 | - rank 0 process will register the given dataset. 103 | - rank 0 process will broadcast job name and dataset id. 104 | - all rank processes will consume from the same job/dataset. 105 | 106 | Without this, dataset workers will try to serve 1 job per rank process and OOM. 107 | 108 | """ 109 | if not env.has_readers(): 110 | return dataset 111 | dataset_service = env.get_dds() 112 | 113 | logging.info(f"using DDS = {dataset_service}") 114 | dataset_id, job_name = register_dataset(dataset=dataset, dataset_service=dataset_service) 115 | dataset = distribute_from_dataset_id( 116 | dataset_service=dataset_service, dataset_id=dataset_id, job_name=job_name 117 | ) 118 | return dataset 119 | 120 | 121 | if __name__ == "__main__": 122 | maybe_start_dataset_service() 123 | -------------------------------------------------------------------------------- /reader/test_dataset.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import os 3 | from unittest.mock import patch 4 | 5 | import tml.reader.utils as reader_utils 6 | from tml.reader.dataset import Dataset 7 | 8 | import pyarrow as pa 9 | import pyarrow.parquet as pq 10 | import pytest 11 | import torch 12 | 13 | 14 | def create_dataset(tmpdir): 15 | 16 | table = pa.table( 17 | { 18 | "year": [2020, 2022, 2021, 2022, 2019, 2021], 19 | "n_legs": [2, 2, 4, 4, 5, 100], 20 | } 21 | ) 22 | file_path = tmpdir 23 | pq.write_to_dataset(table, root_path=str(file_path)) 24 | 25 | class MockDataset(Dataset): 26 | def __init__(self, *args, **kwargs): 27 | super().__init__(*args, **kwargs) 28 | self._pa_to_batch = reader_utils.create_default_pa_to_batch(self._schema) 29 | 30 | def pa_to_batch(self, batch): 31 | return self._pa_to_batch(batch) 32 | 33 | return MockDataset(file_pattern=str(file_path / "*"), batch_size=2) 34 | 35 | 36 | def test_dataset(tmpdir): 37 | ds = create_dataset(tmpdir) 38 | batch = next(iter(ds.dataloader(remote=False))) 39 | assert batch.batch_size == 2 40 | assert torch.equal(batch.year, torch.Tensor([2020, 2022])) 41 | assert torch.equal(batch.n_legs, torch.Tensor([2, 2])) 42 | 43 | 44 | @pytest.mark.skipif( 45 | os.environ.get("GITHUB_WORKSPACE") is not None, 46 | 
reason="Multiprocessing doesn't work on github yet.", 47 | ) 48 | def test_distributed_dataset(tmpdir): 49 | MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"} 50 | 51 | def _client(): 52 | with patch.dict(os.environ, MOCK_ENV): 53 | with patch( 54 | "tml.reader.dataset.env.get_flight_server_addresses", return_value=["grpc://localhost:2222"] 55 | ): 56 | ds = create_dataset(tmpdir) 57 | batch = next(iter(ds.dataloader(remote=True))) 58 | assert batch.batch_size == 2 59 | assert torch.equal(batch.year, torch.Tensor([2020, 2022])) 60 | assert torch.equal(batch.n_legs, torch.Tensor([2, 2])) 61 | 62 | def _worker(): 63 | ds = create_dataset(tmpdir) 64 | ds.serve() 65 | 66 | worker = mp.Process(target=_worker) 67 | client = mp.Process(target=_client) 68 | worker.start() 69 | client.start() 70 | client.join() 71 | assert not client.exitcode 72 | worker.kill() 73 | client.kill() 74 | -------------------------------------------------------------------------------- /reader/test_utils.py: -------------------------------------------------------------------------------- 1 | import tml.reader.utils as reader_utils 2 | 3 | 4 | def test_rr(): 5 | options = ["a", "b", "c"] 6 | rr = reader_utils.roundrobin(options) 7 | for i, v in enumerate(rr): 8 | assert v == options[i % 3] 9 | if i > 4: 10 | break 11 | -------------------------------------------------------------------------------- /reader/utils.py: -------------------------------------------------------------------------------- 1 | """Reader utilities.""" 2 | import itertools 3 | import time 4 | from typing import Optional 5 | 6 | from tml.common.batch import DataclassBatch 7 | from tml.ml_logging.torch_logging import logging 8 | 9 | import pyarrow as pa 10 | import torch 11 | 12 | 13 | def roundrobin(*iterables): 14 | """Round robin through provided iterables, useful for simple load balancing. 15 | 16 | Adapted from https://docs.python.org/3/library/itertools.html. 17 | 18 | """ 19 | num_active = len(iterables) 20 | nexts = itertools.cycle(iter(it).__next__ for it in iterables) 21 | while num_active: 22 | try: 23 | for _next in nexts: 24 | result = _next() 25 | yield result 26 | except StopIteration: 27 | # Remove the iterator we just exhausted from the cycle. 
28 |       num_active -= 1
29 |       nexts = itertools.cycle(itertools.islice(nexts, num_active))
30 |       logging.warning(f"Iterable exhausted, {num_active} iterables left.")
31 |     except Exception as exc:
32 |       logging.warning(f"Iterable raised exception {exc}.")
33 |       # NB: raising here aborts the stream; replace `raise` with `continue` to skip instead.
34 |       raise
35 | 
36 | 
37 | def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):
38 |   num_examples = 0
39 |   prev = time.perf_counter()
40 |   for idx, batch in enumerate(data_loader):
41 |     if idx > max_steps:
42 |       break
43 |     if peek and idx % peek == 0:
44 |       logging.info(f"Batch: {batch}")
45 |     num_examples += batch.batch_size
46 |     if idx % frequency == 0:
47 |       now = time.perf_counter()
48 |       elapsed = now - prev
49 |       logging.info(
50 |         f"step: {idx}, "
51 |         f"elapsed(s): {elapsed}, "
52 |         f"examples: {num_examples}, "
53 |         f"ex/s: {num_examples / elapsed}, "
54 |       )
55 |       prev = now
56 |       num_examples = 0
57 | 
58 | 
59 | def pa_to_torch(array: pa.array) -> torch.Tensor:
60 |   return torch.from_numpy(array.to_numpy())
61 | 
62 | 
63 | def create_default_pa_to_batch(schema):
64 |   """Builds a converter from pa.RecordBatch to DataclassBatch, imputing nulls with type-appropriate defaults."""
65 |   _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
66 | 
67 |   def get_imputation_value(pa_type):
68 |     type_map = {
69 |       pa.float64(): pa.scalar(0, type=pa.float64()),
70 |       pa.int64(): pa.scalar(0, type=pa.int64()),
71 |       pa.string(): pa.scalar("", type=pa.string()),
72 |     }
73 |     if pa_type not in type_map:
74 |       raise Exception(f"Imputation for type {pa_type} not supported.")
75 |     return type_map[pa_type]
76 | 
77 |   def _impute(array: pa.array) -> pa.array:
78 |     return array.fill_null(get_imputation_value(array.type))
79 | 
80 |   def _column_to_tensor(record_batch: pa.RecordBatch):
81 |     tensors = {
82 |       col_name: pa_to_torch(_impute(record_batch.column(col_name)))
83 |       for col_name in record_batch.schema.names
84 |     }
85 |     return _CustomBatch(**tensors)
86 | 
87 |   return _column_to_tensor
88 | 
--------------------------------------------------------------------------------
/tools/pq.py:
--------------------------------------------------------------------------------
1 | """Local reader of parquet files.
2 | 
3 | 1. Make sure you are initialized locally:
4 | ```
5 | ./images/init_venv.sh
6 | ```
7 | 2. Activate
8 | ```
9 | source ~/tml_venv/bin/activate
10 | ```
11 | 3. Use tool, e.g.
12 | 
13 | `head` prints the first `--num` rows of the dataset.
14 | ```
15 | python3 tools/pq.py \
16 |   --num 5 --path "tweet_eng/small/edges/all/*" \
17 |   head
18 | ```
19 | 
20 | `distinct` prints the observed values in the first `--num` rows for the specified columns.
21 | ```
22 | python3 tools/pq.py \
23 |   --num 1000000000 --columns '["rel"]' \
24 |   --path "tweet_eng/small/edges/all/*" \
25 |   distinct
26 | ```
27 | 
28 | """
29 | from typing import List, Optional
30 | 
31 | from tml.common.filesystem import infer_fs
32 | 
33 | import fire
34 | import pandas as pd
35 | import pyarrow as pa
36 | import pyarrow.dataset as pads
37 | import pyarrow.parquet as pq
38 | 
39 | 
40 | def _create_dataset(path: str):
41 |   fs = infer_fs(path)
42 |   files = fs.glob(path)
43 |   return pads.dataset(files, format="parquet", filesystem=fs)
44 | 
45 | 
46 | class PqReader:
47 |   def __init__(
48 |     self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
49 |   ):
50 |     self._ds = _create_dataset(path)
51 |     self._batch_size = batch_size
52 |     self._num = num
53 |     self._columns = columns
54 | 
55 |   def __iter__(self):
56 |     batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
57 |     rows_seen = 0
58 |     for record in batches:
59 |       if self._num and rows_seen >= self._num:
60 |         break
61 |       yield record
62 |       rows_seen += record.num_rows
63 | 
64 |   def _head(self):
65 |     total_read = self._num * self.bytes_per_row
66 |     if total_read >= int(500e6):
67 |       raise Exception(
68 |         "Sorry you're trying to read more than 500 MB " f"into memory ({total_read} bytes)."
69 |       )
70 |     return self._ds.head(self._num, columns=self._columns)
71 | 
72 |   @property
73 |   def bytes_per_row(self) -> int:
74 |     nbits = 0
75 |     for t in self._ds.schema.types:
76 |       try:
77 |         nbits += t.bit_width
78 |       except Exception:
79 |         # Just estimate size if it is variable
80 |         nbits += 8
81 |     return nbits // 8
82 | 
83 |   def schema(self):
84 |     print(f"\n# Schema\n{self._ds.schema}")
85 | 
86 |   def head(self):
87 |     """Displays first --num rows."""
88 |     print(self._head().to_pandas())
89 | 
90 |   def distinct(self):
91 |     """Displays unique values seen in specified columns in the first `--num` rows.
92 | 
93 |     Useful for getting an approximate vocabulary for certain columns.
94 | 
95 |     """
96 |     head = self._head()
97 |     for col_name, column in zip(head.column_names, head.columns):
98 |       print(col_name)
99 |       print("unique:", column.unique().to_pylist())
100 | 
101 | 
102 | if __name__ == "__main__":
103 |   pd.set_option("display.max_columns", None)
104 |   pd.set_option("display.max_rows", None)
105 |   fire.Fire(PqReader)
106 | 
--------------------------------------------------------------------------------
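For quick programmatic use, `PqReader` can also be driven from Python rather than through fire (a sketch; the import assumes the repo root is on PYTHONPATH as in the scripts above, and the parquet path is illustrative):

```python
from tools.pq import PqReader

# Peek at a local parquet dataset: print the schema, then the first 5 rows.
reader = PqReader(path="/tmp/edges/*.parquet", num=5)
reader.schema()
reader.head()
```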