├── common ├── __init__.py ├── checkpointing │ └── __init__.py ├── filesystem │ ├── __init__.py │ ├── test_infer_fs.py │ └── util.py ├── test_device.py ├── device.py ├── testing_utils.py ├── wandb.py ├── utils.py ├── modules │ └── embedding │ │ ├── embedding.py │ │ └── config.py ├── batch.py ├── log_weights.py └── run_training.py ├── core ├── __init__.py ├── loss_type.py ├── config │ ├── __init__.py │ ├── test_config_load.py │ ├── config_load.py │ ├── base_config_test.py │ ├── training.py │ └── base_config.py ├── debug_training_loop.py ├── test_train_pipeline.py ├── losses.py ├── metric_mixin.py ├── test_metrics.py └── metrics.py ├── reader ├── __init__.py ├── test_utils.py ├── test_dataset.py ├── utils.py ├── dds.py └── dataset.py ├── ml_logging ├── __init__.py ├── test_torch_logging.py ├── absl_logging.py └── torch_logging.py ├── projects ├── __init__.py ├── home │ └── recap │ │ ├── __init__.py │ │ ├── data │ │ ├── __init__.py │ │ ├── generate_random_data.py │ │ ├── util.py │ │ ├── tfe_parsing.py │ │ ├── preprocessors.py │ │ └── config.py │ │ ├── optimizer │ │ ├── __init__.py │ │ ├── config.py │ │ └── optimizer.py │ │ ├── model │ │ ├── __init__.py │ │ ├── numeric_calibration.py │ │ ├── mlp.py │ │ ├── model_and_loss.py │ │ ├── mask_net.py │ │ └── feature_transform.py │ │ ├── script │ │ ├── create_random_data.sh │ │ └── run_local.sh │ │ ├── config.py │ │ ├── config │ │ └── home_recap_2022 │ │ │ └── segdense.json │ │ ├── README.md │ │ ├── main.py │ │ └── embedding │ │ └── config.py └── twhin │ ├── data │ ├── test_data.py │ ├── config.py │ ├── data.py │ ├── test_edges.py │ └── edges.py │ ├── machines.yaml │ ├── scripts │ ├── run_in_docker.sh │ └── docker_run.sh │ ├── metrics.py │ ├── config.py │ ├── test_optimizer.py │ ├── README.md │ ├── config │ └── local.yaml │ ├── models │ ├── config.py │ ├── test_models.py │ └── models.py │ ├── optimizer.py │ └── run.py ├── optimizers ├── __init__.py ├── config.py └── optimizer.py ├── metrics ├── __init__.py ├── aggregation.py 
└── auroc.py ├── pyproject.toml ├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── machines ├── is_venv.py ├── list_ops.py ├── get_env.py └── environment.py ├── images ├── init_venv.sh └── requirements.txt ├── .github └── workflows │ └── main.yml ├── LICENSE.torchrec ├── tools └── pq.py └── model.py /common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /reader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ml_logging/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/home/recap/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/home/recap/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from tml.optimizers.optimizer import compute_lr 2 | -------------------------------------------------------------------------------- /common/checkpointing/__init__.py: 
-------------------------------------------------------------------------------- 1 | from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot 2 | -------------------------------------------------------------------------------- /common/filesystem/__init__.py: -------------------------------------------------------------------------------- 1 | from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs 2 | -------------------------------------------------------------------------------- /projects/home/recap/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from tml.projects.home.recap.optimizer.optimizer import build_optimizer 2 | -------------------------------------------------------------------------------- /projects/twhin/data/test_data.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import Mock 3 | 4 | 5 | def test_create_dataset(): 6 | pass 7 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import StableMean # noqa 2 | from .auroc import AUROCWithMWU # noqa 3 | from .rce import NRCE, RCE # noqa 4 | -------------------------------------------------------------------------------- /core/loss_type.py: -------------------------------------------------------------------------------- 1 | """Loss type enums.""" 2 | from enum import Enum 3 | 4 | 5 | class LossType(str, Enum): 6 | CROSS_ENTROPY = "cross_entropy" 7 | BCE_WITH_LOGITS = "bce_with_logits" 8 | -------------------------------------------------------------------------------- /projects/home/recap/model/__init__.py: -------------------------------------------------------------------------------- 1 | from tml.projects.home.recap.model.entrypoint import ( 2 | create_ranking_model, 3 | sanitize, 4 | 
unsanitize, 5 | MultiTaskRankingModel, 6 | ) 7 | from tml.projects.home.recap.model.model_and_loss import ModelAndLoss 8 | -------------------------------------------------------------------------------- /projects/twhin/machines.yaml: -------------------------------------------------------------------------------- 1 | chief: &gpu 2 | mem: 1.4Ti 3 | cpu: 24 4 | num_accelerators: 16 5 | accelerator_type: a100 6 | dataset_dispatcher: 7 | mem: 2Gi 8 | cpu: 2 9 | num_dataset_workers: 4 10 | dataset_worker: 11 | mem: 14Gi 12 | cpu: 2 13 | -------------------------------------------------------------------------------- /reader/test_utils.py: -------------------------------------------------------------------------------- 1 | import tml.reader.utils as reader_utils 2 | 3 | 4 | def test_rr(): 5 | options = ["a", "b", "c"] 6 | rr = reader_utils.roundrobin(options) 7 | for i, v in enumerate(rr): 8 | assert v == options[i % 3] 9 | if i > 4: 10 | break 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | include = '\.pyi?$' 4 | exclude = ''' 5 | /( 6 | \.git 7 | | \.hg 8 | | \.pem 9 | | \.mypy_cache 10 | | \.tox 11 | | \.venv 12 | | _build 13 | | buck-out 14 | | build 15 | | dist 16 | )/ 17 | ''' 18 | -------------------------------------------------------------------------------- /core/config/__init__.py: -------------------------------------------------------------------------------- 1 | from tml.core.config.base_config import BaseConfig 2 | from tml.core.config.config_load import load_config_from_yaml 3 | 4 | # Make mypy happy by explicitly rexporting the symbols intended for end user use. 
5 | __all__ = ["BaseConfig", "load_config_from_yaml"] 6 | -------------------------------------------------------------------------------- /projects/twhin/scripts/run_in_docker.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | torchrun \ 4 | --standalone \ 5 | --nnodes 1 \ 6 | --nproc_per_node 2 \ 7 | /usr/src/app/tml/projects/twhin/run.py \ 8 | --config_yaml_path="/usr/src/app/tml/projects/twhin/config/local.yaml" \ 9 | --save_dir="/some/save/dir" 10 | -------------------------------------------------------------------------------- /projects/twhin/scripts/docker_run.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | docker run -it --rm \ 4 | -v $HOME/workspace/tml:/usr/src/app/tml \ 5 | -v $HOME/.config:/root/.config \ 6 | -w /usr/src/app \ 7 | -e PYTHONPATH="/usr/src/app/" \ 8 | --network host \ 9 | -e SPEC_TYPE=chief \ 10 | local/torch \ 11 | bash tml/projects/twhin/scripts/run_in_docker.sh 12 | -------------------------------------------------------------------------------- /projects/twhin/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchmetrics as tm 3 | 4 | import tml.core.metrics as core_metrics 5 | 6 | 7 | def create_metrics( 8 | device: torch.device, 9 | ): 10 | metrics = dict() 11 | metrics.update( 12 | { 13 | "AUC": core_metrics.Auc(128), 14 | } 15 | ) 16 | metrics = tm.MetricCollection(metrics).to(device) 17 | return metrics 18 | -------------------------------------------------------------------------------- /common/test_device.py: -------------------------------------------------------------------------------- 1 | """Minimal test for device. 2 | 3 | Mostly a test that this can be imported properly even tho moved. 
4 | """ 5 | from unittest.mock import patch 6 | 7 | import tml.common.device as device_utils 8 | 9 | 10 | def test_device(): 11 | with patch("tml.common.device.dist.init_process_group"): 12 | device = device_utils.setup_and_get_device(tf_ok=False) 13 | assert device.type == "cpu" 14 | -------------------------------------------------------------------------------- /projects/home/recap/script/create_random_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Runs from inside venv 4 | 5 | rm -rf $HOME/tmp/runs/recap_local_random_data 6 | python -m tml.machines.is_venv || exit 1 7 | export TML_BASE="$(git rev-parse --show-toplevel)" 8 | 9 | mkdir -p $HOME/tmp/recap_local_random_data 10 | python projects/home/recap/data/generate_random_data.py --config_path $(pwd)/projects/home/recap/config/local_prod.yaml 11 | -------------------------------------------------------------------------------- /projects/twhin/data/config.py: -------------------------------------------------------------------------------- 1 | from tml.core.config import base_config 2 | 3 | import pydantic 4 | 5 | 6 | class TwhinDataConfig(base_config.BaseConfig): 7 | data_root: str 8 | per_replica_batch_size: pydantic.PositiveInt 9 | global_negatives: int 10 | in_batch_negatives: int 11 | limit: pydantic.PositiveInt 12 | offset: pydantic.PositiveInt = pydantic.Field( 13 | None, description="The offset to start reading from." 14 | ) 15 | -------------------------------------------------------------------------------- /common/filesystem/test_infer_fs.py: -------------------------------------------------------------------------------- 1 | """Minimal test for infer_fs. 
2 | 3 | Mostly a test that it returns an object 4 | """ 5 | from tml.common.filesystem import infer_fs 6 | 7 | 8 | def test_infer_fs(): 9 | local_path = "/tmp/local_path" 10 | gcs_path = "gs://somebucket/somepath" 11 | 12 | local_fs = infer_fs(local_path) 13 | gcs_fs = infer_fs(gcs_path) 14 | 15 | # This should return two different objects 16 | assert local_fs != gcs_fs 17 | -------------------------------------------------------------------------------- /projects/home/recap/script/run_local.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Runs from inside venv 4 | rm -rf $HOME/tmp/runs/recap_local_debug 5 | mkdir -p $HOME/tmp/runs/recap_local_debug 6 | python -m tml.machines.is_venv || exit 1 7 | export TML_BASE="$(git rev-parse --show-toplevel)" 8 | 9 | torchrun \ 10 | --standalone \ 11 | --nnodes 1 \ 12 | --nproc_per_node 1 \ 13 | projects/home/recap/main.py \ 14 | --config_path $(pwd)/projects/home/recap/config/local_prod.yaml \ 15 | $@ 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Mac 2 | .DS_Store 3 | 4 | # Vim 5 | *.py.swp 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # Installer logs 29 | pip-log.txt 30 | pip-delete-this-directory.txt 31 | 32 | # Unit test / coverage reports 33 | .hypothesis 34 | 35 | venv 36 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pausan/cblack 3 | rev: release-22.3.0 4 | hooks: 
5 | - id: cblack 6 | name: cblack 7 | description: "Black: The uncompromising Python code formatter - 2 space indent fork" 8 | entry: cblack . -l 100 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v2.3.0 11 | hooks: 12 | - id: trailing-whitespace 13 | - id: end-of-file-fixer 14 | - id: check-yaml 15 | - id: check-added-large-files 16 | - id: check-merge-conflict 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This project open-sources some of the ML models used at Twitter. 2 | 3 | Currently these are: 4 | 5 | 1. The "For You" Heavy Ranker (projects/home/recap). 6 | 7 | 2. TwHIN embeddings (projects/twhin) https://arxiv.org/abs/2202.05387 8 | 9 | 10 | This project can be run inside a python virtualenv. We have only tried this on Linux machines, and because we use torchrec, it works best with an Nvidia GPU. To set up, run 11 | 12 | `./images/init_venv.sh` (Linux only). 13 | 14 | The READMEs of each project contain instructions about how to run each project.
15 | -------------------------------------------------------------------------------- /projects/twhin/config.py: -------------------------------------------------------------------------------- 1 | from tml.core.config import base_config 2 | from tml.projects.twhin.data.config import TwhinDataConfig 3 | from tml.projects.twhin.models.config import TwhinModelConfig 4 | from tml.core.config.training import RuntimeConfig, TrainingConfig 5 | 6 | import pydantic 7 | 8 | 9 | class TwhinConfig(base_config.BaseConfig): 10 | runtime: RuntimeConfig = pydantic.Field(RuntimeConfig()) 11 | training: TrainingConfig = pydantic.Field(TrainingConfig()) 12 | model: TwhinModelConfig 13 | train_data: TwhinDataConfig 14 | validation_data: TwhinDataConfig 15 | -------------------------------------------------------------------------------- /machines/is_venv.py: -------------------------------------------------------------------------------- 1 | """This is intended to be run as a module. 2 | e.g. python -m tml.machines.is_venv 3 | 4 | Exits with 0 if running in venv, otherwise 1.
5 | """ 6 | 7 | import sys 8 | import logging 9 | 10 | 11 | def is_venv(): 12 | # See https://stackoverflow.com/questions/1871549/determine-if-python-is-running-inside-virtualenv 13 | return sys.base_prefix != sys.prefix 14 | 15 | 16 | def _main(): 17 | if is_venv(): 18 | logging.info("In venv %s", sys.prefix) 19 | sys.exit(0) 20 | else: 21 | logging.error("Not in venv") 22 | sys.exit(1) 23 | 24 | 25 | if __name__ == "__main__": 26 | _main() 27 | -------------------------------------------------------------------------------- /ml_logging/test_torch_logging.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tml.ml_logging.torch_logging import logging 4 | 5 | 6 | class Testtlogging(unittest.TestCase): 7 | def test_warn_once(self): 8 | with self.assertLogs(level="INFO") as captured_logs: 9 | logging.info("first info") 10 | logging.warning("first warning") 11 | logging.warning("first warning") 12 | logging.info("second info") 13 | 14 | self.assertEqual( 15 | captured_logs.output, 16 | [ 17 | "INFO:absl:first info", 18 | "WARNING:absl:first warning", 19 | "INFO:absl:second info", 20 | ], 21 | ) 22 | -------------------------------------------------------------------------------- /projects/home/recap/model/numeric_calibration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class NumericCalibration(torch.nn.Module): 5 | def __init__( 6 | self, 7 | pos_downsampling_rate: float, 8 | neg_downsampling_rate: float, 9 | ): 10 | super().__init__() 11 | 12 | # Using buffer to make sure they are on correct device (and not moved every time). 13 | # Will also be part of state_dict. 
14 | self.register_buffer( 15 | "ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True 16 | ) 17 | 18 | def forward(self, probs: torch.Tensor): 19 | return probs * self.ratio / (1.0 - probs + (self.ratio * probs)) 20 | -------------------------------------------------------------------------------- /common/filesystem/util.py: -------------------------------------------------------------------------------- 1 | """Utilities for interacting with the file systems.""" 2 | from fsspec.implementations.local import LocalFileSystem 3 | import gcsfs 4 | 5 | 6 | GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1) 7 | LOCAL_FS = LocalFileSystem() 8 | 9 | 10 | def infer_fs(path: str): 11 | if path.startswith("gs://"): 12 | return GCS_FS 13 | elif path.startswith("hdfs://"): 14 | # We can probably use pyarrow HDFS to support this. 15 | raise NotImplementedError("HDFS not yet supported") 16 | else: 17 | return LOCAL_FS 18 | 19 | 20 | def is_local_fs(fs): 21 | return fs == LOCAL_FS 22 | 23 | 24 | def is_gcs_fs(fs): 25 | return fs == GCS_FS 26 | -------------------------------------------------------------------------------- /core/config/test_config_load.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from tml.core.config import BaseConfig, load_config_from_yaml 4 | 5 | import pydantic 6 | import getpass 7 | import pydantic 8 | 9 | 10 | class _PointlessConfig(BaseConfig): 11 | a: int 12 | user: str 13 | 14 | 15 | def test_load_config_from_yaml(tmp_path): 16 | yaml_path = tmp_path.joinpath("test.yaml").as_posix() 17 | with open(yaml_path, "w") as yaml_file: 18 | yaml_file.write("""a: 3\nuser: ${USER}\n""") 19 | 20 | pointless_config = load_config_from_yaml(_PointlessConfig, yaml_path) 21 | 22 | assert pointless_config.a == 3 23 | assert pointless_config.user == getpass.getuser() 24 | -------------------------------------------------------------------------------- 
/projects/twhin/data/data.py: -------------------------------------------------------------------------------- 1 | from tml.projects.twhin.data.config import TwhinDataConfig 2 | from tml.projects.twhin.models.config import TwhinModelConfig 3 | from tml.projects.twhin.data.edges import EdgesDataset 4 | 5 | 6 | def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig): 7 | tables = model_config.embeddings.tables 8 | table_sizes = {table.name: table.num_embeddings for table in tables} 9 | relations = model_config.relations 10 | 11 | pos_batch_size = data_config.per_replica_batch_size 12 | 13 | return EdgesDataset( 14 | file_pattern=data_config.data_root, 15 | relations=relations, 16 | table_sizes=table_sizes, 17 | batch_size=pos_batch_size, 18 | ) 19 | -------------------------------------------------------------------------------- /common/device.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | def maybe_setup_tensorflow(): 8 | try: 9 | import tensorflow as tf 10 | except ImportError: 11 | pass 12 | else: 13 | tf.config.set_visible_devices([], "GPU") # disable tf gpu 14 | 15 | 16 | def setup_and_get_device(tf_ok: bool = True) -> torch.device: 17 | if tf_ok: 18 | maybe_setup_tensorflow() 19 | 20 | device = torch.device("cpu") 21 | backend = "gloo" 22 | if torch.cuda.is_available(): 23 | rank = os.environ["LOCAL_RANK"] 24 | device = torch.device(f"cuda:{rank}") 25 | backend = "nccl" 26 | torch.cuda.set_device(device) 27 | if not torch.distributed.is_initialized(): 28 | dist.init_process_group(backend) 29 | 30 | return device 31 | -------------------------------------------------------------------------------- /images/init_venv.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | if [[ "$(uname)" == "Darwin" ]]; then 4 | echo "Only supported on Linux." 
5 | exit 1 6 | fi 7 | 8 | # You may need to point this to a version of python 3.10 9 | PYTHONBIN="/opt/ee/python/3.10/bin/python3.10" 10 | echo Using "PYTHONBIN=$PYTHONBIN" 11 | 12 | # Put venv in $HOME/tml_venv; these things are not made to last, just rebuild. 13 | VENV_PATH="$HOME/tml_venv" 14 | rm -rf "$VENV_PATH" 15 | "$PYTHONBIN" -m venv "$VENV_PATH" 16 | 17 | # shellcheck source=/dev/null 18 | . "$VENV_PATH/bin/activate" 19 | 20 | pip --require-virtual install -U pip 21 | pip --require-virtualenv install --no-deps -r images/requirements.txt 22 | 23 | ln -s "$(pwd)" "$VENV_PATH/lib/python3.10/site-packages/tml" 24 | 25 | echo "Now run source ${VENV_PATH}/bin/activate" to get going. 26 | -------------------------------------------------------------------------------- /common/testing_utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import datetime 3 | import os 4 | from unittest.mock import patch 5 | 6 | import torch.distributed as dist 7 | from tml.ml_logging.torch_logging import logging 8 | 9 | 10 | MOCK_ENV = { 11 | "LOCAL_RANK": "0", 12 | "WORLD_SIZE": "1", 13 | "LOCAL_WORLD_SIZE": "1", 14 | "MASTER_ADDR": "localhost", 15 | "MASTER_PORT": "29501", 16 | "RANK": "0", 17 | } 18 | 19 | 20 | @contextmanager 21 | def mock_pg(): 22 | with patch.dict(os.environ, MOCK_ENV): 23 | try: 24 | dist.init_process_group( 25 | backend="gloo", 26 | timeout=datetime.timedelta(1), 27 | ) 28 | yield 29 | except: 30 | dist.destroy_process_group() 31 | raise 32 | finally: 33 | dist.destroy_process_group() 34 | -------------------------------------------------------------------------------- /core/config/config_load.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import string 3 | import getpass 4 | import os 5 | from typing import Type 6 | 7 | from tml.core.config.base_config import BaseConfig 8 | 9 | 10 | def load_config_from_yaml(config_type:
Type[BaseConfig], yaml_path: str): 11 | """Recommended method to load a config file (a yaml file) and parse it. 12 | 13 | Because we have a shared filesystem, the recommended route to running jobs is to put modified 14 | config files with the desired parameters somewhere on the filesystem and run jobs pointing to them. 15 | """ 16 | 17 | def _substitute(s): 18 | return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser()) 19 | 20 | with open(yaml_path, "r") as f: 21 | raw_contents = f.read() 22 | obj = yaml.safe_load(_substitute(raw_contents)) 23 | 24 | return config_type.parse_obj(obj) 25 | -------------------------------------------------------------------------------- /ml_logging/absl_logging.py: -------------------------------------------------------------------------------- 1 | """Sets up logging through absl for training usage. 2 | 3 | - Redirects logging to sys.stdout so that severity levels in GCP Stackdriver are accurate. 4 | 5 | Usage: 6 | >>> from tml.ml_logging.absl_logging import logging 7 | >>> logging.info(f"Properly logged as INFO level in GCP Stackdriver.") 8 | 9 | """ 10 | import logging as py_logging 11 | import sys 12 | 13 | from absl import logging as logging 14 | 15 | 16 | def setup_absl_logging(): 17 | """Make sure that absl logging pushes to stdout rather than stderr.""" 18 | logging.get_absl_handler().python_handler.stream = sys.stdout 19 | formatter = py_logging.Formatter( 20 | fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s" 21 | ) 22 | logging.get_absl_handler().setFormatter(formatter) 23 | logging.set_verbosity(logging.INFO) 24 | 25 | 26 | setup_absl_logging() 27 | -------------------------------------------------------------------------------- /common/wandb.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import tml.core.config as base_config 4 | 5 | import pydantic 6 | 7 | 8 | class
WandbConfig(base_config.BaseConfig): 9 | host: str = pydantic.Field( 10 | "https://https--wandb--prod--wandb.service.qus1.twitter.biz/", 11 | description="Host of Weights and Biases instance, passed to login.", 12 | ) 13 | key_path: str = pydantic.Field(description="Path to key file.") 14 | 15 | name: str = pydantic.Field(None, description="Name of the experiment, passed to init.") 16 | entity: str = pydantic.Field(None, description="Name of user/service account, passed to init.") 17 | project: str = pydantic.Field(None, description="Name of wandb project, passed to init.") 18 | tags: List[str] = pydantic.Field([], description="List of tags, passed to init.") 19 | notes: str = pydantic.Field(None, description="Notes, passed to init.") 20 | metadata: Dict[str, Any] = pydantic.Field(None, description="Additional metadata to log.") 21 | -------------------------------------------------------------------------------- /core/config/base_config_test.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from tml.core.config import BaseConfig 4 | 5 | import pydantic 6 | 7 | 8 | class BaseConfigTest(TestCase): 9 | def test_extra_forbidden(self): 10 | class Config(BaseConfig): 11 | x: int 12 | 13 | Config(x=1) 14 | with self.assertRaises(pydantic.ValidationError): 15 | Config(x=1, y=2) 16 | 17 | def test_one_of(self): 18 | class Config(BaseConfig): 19 | x: int = pydantic.Field(None, one_of="f") 20 | y: int = pydantic.Field(None, one_of="f") 21 | 22 | with self.assertRaises(pydantic.ValidationError): 23 | Config() 24 | Config(x=1) 25 | Config(y=1) 26 | with self.assertRaises(pydantic.ValidationError): 27 | Config(x=1, y=3) 28 | 29 | def test_at_most_one_of(self): 30 | class Config(BaseConfig): 31 | x: int = pydantic.Field(None, at_most_one_of="f") 32 | y: str = pydantic.Field(None, at_most_one_of="f") 33 | 34 | Config() 35 | Config(x=1) 36 | Config(y="a") 37 | with 
self.assertRaises(pydantic.ValidationError): 38 | Config(x=1, y="a") 39 | -------------------------------------------------------------------------------- /projects/twhin/test_optimizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import unittest 3 | 4 | from tml.projects.twhin.models.models import TwhinModel, apply_optimizers 5 | from tml.projects.twhin.models.test_models import twhin_model_config, twhin_data_config 6 | from tml.projects.twhin.optimizer import build_optimizer 7 | from tml.model import maybe_shard_model 8 | from tml.common.testing_utils import mock_pg 9 | 10 | 11 | import torch 12 | from torch.nn import functional as F 13 | 14 | 15 | def test_twhin_optimizer(): 16 | model_config = twhin_model_config() 17 | data_config = twhin_data_config() 18 | 19 | loss_fn = F.binary_cross_entropy_with_logits 20 | with mock_pg(): 21 | model = TwhinModel(model_config, data_config) 22 | apply_optimizers(model, model_config) 23 | model = maybe_shard_model(model, device=torch.device("cpu")) 24 | 25 | optimizer, _ = build_optimizer(model, model_config) 26 | 27 | # make sure there is one combined fused optimizer and one translation optimizer 28 | assert len(optimizer.optimizers) == 2 29 | fused_opt_tup, _ = optimizer.optimizers 30 | _, fused_opt = fused_opt_tup 31 | 32 | # make sure there are two tables for which the fused opt has parameters 33 | assert len(fused_opt.param_groups) == 2 34 | -------------------------------------------------------------------------------- /projects/twhin/README.md: -------------------------------------------------------------------------------- 1 | Twhin in torchrec 2 | 3 | This project contains code for pretraining dense vector embedding features for Twitter entities. Within Twitter, these embeddings are used for candidate retrieval and as model features in a variety of recommender system models. 
4 | 5 | We obtain entity embeddings based on a variety of graph data within Twitter such as: 6 | "User follows User" 7 | "User favorites Tweet" 8 | "User clicks Advertisement" 9 | 10 | While we cannot release the graph data used to train TwHIN embeddings due to privacy restrictions, heavily subsampled, anonymized open-sourced graph data can be used: 11 | https://huggingface.co/datasets/Twitter/TwitterFollowGraph 12 | https://huggingface.co/datasets/Twitter/TwitterFaveGraph 13 | 14 | The code expects parquet files with three columns: lhs, rel, rhs that refer to the vocab index of the left-hand-side node, relation type, and right-hand-side node of each edge in a graph respectively. 15 | 16 | The location of the data must be specified in the configuration yaml files in projects/twhin/config. 17 | 18 | 19 | Workflow 20 | ======== 21 | - Build local development images `./scripts/build_images.sh` 22 | - Run with `./scripts/docker_run.sh` 23 | - Iterate in image with `./scripts/idocker.sh` 24 | - Run tests with `./scripts/docker_test.sh` 25 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import getpass 3 | import os 4 | import string 5 | from typing import Tuple, Type, TypeVar 6 | 7 | from tml.core.config import base_config 8 | 9 | import fsspec 10 | 11 | C = TypeVar("C", bound=base_config.BaseConfig) 12 | 13 | 14 | def _read_file(f): 15 | with fsspec.open(f) as f: 16 | return f.read() 17 | 18 | 19 | def setup_configuration( 20 | config_type: Type[C], 21 | yaml_path: str, 22 | substitute_env_variable: bool = False, 23 | ) -> Tuple[C, str]: 24 | """Resolves a config at a yaml path. 25 | 26 | Args: 27 | config_type: Pydantic config class to load. 28 | yaml_path: yaml path of the config file.
29 | substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their 30 | environment variable value whenever possible. If an environment variable doesn't exist, 31 | the string is left unchanged. 32 | 33 | Returns: 34 | The pydantic config object. 35 | """ 36 | 37 | def _substitute(s): 38 | if substitute_env_variable: 39 | return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser()) 40 | return s 41 | 42 | assert config_type is not None, "can't use all_config without config_type" 43 | content = _substitute(yaml.safe_load(_read_file(yaml_path))) 44 | return config_type.parse_obj(content) 45 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.10"] 11 | 12 | steps: 13 | - uses: actions/checkout@v3 14 | # - uses: pre-commit/action@v3.0.0 15 | # name: Run pre-commit checks (pylint/yapf/isort) 16 | # env: 17 | # SKIP: insert-license 18 | # with: 19 | # extra_args: --hook-stage push --all-files 20 | - uses: actions/setup-python@v4 21 | with: 22 | python-version: "3.10" 23 | cache: "pip" # caching pip dependencies 24 | - name: install packages 25 | run: | 26 | /usr/bin/python -m pip install --upgrade pip 27 | pip install --no-deps -r images/requirements.txt 28 | # - name: ssh access 29 | # uses: lhotari/action-upterm@v1 30 | # with: 31 | # limit-access-to-actor: true 32 | # limit-access-to-users: arashd 33 | - name: run tests 34 | run: | 35 | # Environment variables are reset in between steps. 
36 | mkdir /tmp/github_testing 37 | ln -s $GITHUB_WORKSPACE /tmp/github_testing/tml 38 | export PYTHONPATH="/tmp/github_testing:$PYTHONPATH" 39 | pytest -vv 40 | -------------------------------------------------------------------------------- /machines/list_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple str.split() parsing of input string 3 | 4 | usage example: 5 | python list_ops.py --input_list=$INPUT [--sep=","] [--op=] [--elem=$INDEX] 6 | 7 | Args: 8 | - input_list: input string 9 | - sep (default ","): separator string 10 | - elem (default 0): integer index 11 | - op (default "select"): either `len` or `select` 12 | - len: prints len(input_list.split(sep)) 13 | - select: prints input_list.split(sep)[elem] 14 | 15 | Typical usage would be in a bash script, e.g.: 16 | 17 | LIST_LEN=$(python list_ops.py --input_list=$INPUT --op=len) 18 | 19 | """ 20 | import tml.machines.environment as env 21 | 22 | from absl import app, flags 23 | 24 | 25 | FLAGS = flags.FLAGS 26 | flags.DEFINE_string("input_list", None, "string to parse as list") 27 | flags.DEFINE_integer("elem", 0, "which element to take") 28 | flags.DEFINE_string("sep", ",", "separator") 29 | flags.DEFINE_string("op", "select", "operation to do") 30 | 31 | 32 | def main(argv): 33 | split_list = FLAGS.input_list.split(FLAGS.sep) 34 | if FLAGS.op == "select": 35 | print(split_list[FLAGS.elem], flush=True) 36 | elif FLAGS.op == "len": 37 | print(len(split_list), flush=True) 38 | else: 39 | raise ValueError(f"operation {FLAGS.op} not recognized.") 40 | 41 | 42 | if __name__ == "__main__": 43 | app.run(main) 44 | -------------------------------------------------------------------------------- /projects/home/recap/optimizer/config.py: -------------------------------------------------------------------------------- 1 | """Optimization configurations for models.""" 2 | 3 | import typing 4 | 5 | import tml.core.config as base_config 6 | import 
tml.optimizers.config as optimizers_config_mod 7 | 8 | import pydantic 9 | 10 | 11 | class RecapAdamConfig(base_config.BaseConfig): 12 | beta_1: float = 0.9 # Momentum term. 13 | beta_2: float = 0.999 # Exponential weighted decay factor. 14 | epsilon: float = 1e-7 # Numerical stability in denominator. 15 | 16 | 17 | class MultiTaskLearningRates(base_config.BaseConfig): 18 | tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field( 19 | description="Learning rates for different towers of the model." 20 | ) 21 | 22 | backbone_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field( 23 | None, description="Learning rate for backbone of the model." 24 | ) 25 | 26 | 27 | class RecapOptimizerConfig(base_config.BaseConfig): 28 | multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field( 29 | None, description="Multiple learning rates for different tasks.", one_of="lr" 30 | ) 31 | 32 | single_task_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field( 33 | None, description="Single task learning rates", one_of="lr" 34 | ) 35 | 36 | adam: RecapAdamConfig = pydantic.Field(one_of="optimizer") 37 | -------------------------------------------------------------------------------- /core/debug_training_loop.py: -------------------------------------------------------------------------------- 1 | """This is a very limited feature training loop useful for interactive debugging. 2 | 3 | It is not intended for actual model training (it is not fast, doesn't compile the model). 4 | It does not support checkpointing. 5 | 6 | suggested use: 7 | 8 | from tml.core import debug_training_loop 9 | debug_training_loop.train(...)
10 | """ 11 | 12 | from typing import Iterable, Optional, Dict, Callable, List 13 | import torch 14 | from torch.optim.lr_scheduler import _LRScheduler 15 | import torchmetrics as tm 16 | 17 | from tml.ml_logging.torch_logging import logging 18 | 19 | 20 | def train( 21 | model: torch.nn.Module, 22 | optimizer: torch.optim.Optimizer, 23 | train_steps: int, 24 | dataset: Iterable, 25 | scheduler: _LRScheduler = None, 26 | # Accept any arguments (to be compatible with the real training loop) 27 | # but just ignore them. 28 | *args, 29 | **kwargs, 30 | ) -> None: 31 | 32 | logging.warning("Running debug training loop, don't use for model training.") 33 | 34 | data_iter = iter(dataset) 35 | for step in range(0, train_steps + 1): 36 | x = next(data_iter) 37 | optimizer.zero_grad() 38 | loss, outputs = model.forward(x) 39 | loss.backward() 40 | optimizer.step() 41 | 42 | if scheduler: 43 | scheduler.step() 44 | 45 | logging.info(f"Step {step} completed. Loss = {loss}") 46 | -------------------------------------------------------------------------------- /machines/get_env.py: -------------------------------------------------------------------------------- 1 | import tml.machines.environment as env 2 | 3 | from absl import app, flags 4 | 5 | 6 | FLAGS = flags.FLAGS 7 | flags.DEFINE_string("property", None, "Which property of the current environment to fetch.") 8 | 9 | 10 | def main(argv): 11 | if FLAGS.property == "using_dds": 12 | print(f"{env.has_readers()}", flush=True) 13 | if FLAGS.property == "has_readers": 14 | print(f"{env.has_readers()}", flush=True) 15 | elif FLAGS.property == "get_task_type": 16 | print(f"{env.get_task_type()}", flush=True) 17 | elif FLAGS.property == "is_datasetworker": 18 | print(f"{env.is_reader()}", flush=True) 19 | elif FLAGS.property == "is_dds_dispatcher": 20 | print(f"{env.is_dispatcher()}", flush=True) 21 | elif FLAGS.property == "get_task_index": 22 | print(f"{env.get_task_index()}", flush=True) 23 | elif FLAGS.property == 
"get_dataset_service": 24 | print(f"{env.get_dds()}", flush=True) 25 | elif FLAGS.property == "get_dds_dispatcher_address": 26 | print(f"{env.get_dds_dispatcher_address()}", flush=True) 27 | elif FLAGS.property == "get_dds_worker_address": 28 | print(f"{env.get_dds_worker_address()}", flush=True) 29 | elif FLAGS.property == "get_dds_port": 30 | print(f"{env.get_reader_port()}", flush=True) 31 | elif FLAGS.property == "get_dds_journaling_dir": 32 | print(f"{env.get_dds_journaling_dir()}", flush=True) 33 | elif FLAGS.property == "should_start_dds": 34 | print(env.is_reader() or env.is_dispatcher(), flush=True) 35 | 36 | 37 | if __name__ == "__main__": 38 | app.run(main) 39 | -------------------------------------------------------------------------------- /core/config/training.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from tml.common.wandb import WandbConfig 4 | from tml.core.config import base_config 5 | from tml.projects.twhin.data.config import TwhinDataConfig 6 | from tml.projects.twhin.models.config import TwhinModelConfig 7 | 8 | import pydantic 9 | 10 | 11 | class RuntimeConfig(base_config.BaseConfig): 12 | wandb: WandbConfig = pydantic.Field(None) 13 | enable_tensorfloat32: bool = pydantic.Field( 14 | False, description="Use tensorfloat32 if on Ampere devices." 
15 | ) 16 | enable_amp: bool = pydantic.Field(False, description="Enable automatic mixed precision.") 17 | 18 | 19 | class TrainingConfig(base_config.BaseConfig): 20 | save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.") 21 | num_train_steps: pydantic.PositiveInt = 10000 22 | initial_checkpoint_dir: str = pydantic.Field( 23 | None, description="Directory of initial checkpoints", at_most_one_of="initialization" 24 | ) 25 | checkpoint_every_n: pydantic.PositiveInt = 1000 26 | checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field( 27 | None, description="Maximum number of checkpoints to keep. Defaults to keeping all." 28 | ) 29 | train_log_every_n: pydantic.PositiveInt = 1000 30 | num_eval_steps: int = pydantic.Field( 31 | 16384, description="Number of evaluation steps. If < 0 the entire dataset will be used." 32 | ) 33 | eval_log_every_n: pydantic.PositiveInt = 5000 34 | 35 | eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60 36 | 37 | gradient_accumulation: int = pydantic.Field( 38 | None, description="Number of replica steps to accumulate gradients." 
39 | ) 40 | num_epochs: pydantic.PositiveInt = 1 41 | -------------------------------------------------------------------------------- /projects/twhin/config/local.yaml: -------------------------------------------------------------------------------- 1 | runtime: 2 | enable_amp: false 3 | training: 4 | save_dir: "/tmp/model" 5 | num_train_steps: 100000 6 | checkpoint_every_n: 100000 7 | train_log_every_n: 10 8 | num_eval_steps: 1000 9 | eval_log_every_n: 500 10 | eval_timeout_in_s: 10000 11 | num_epochs: 5 12 | model: 13 | translation_optimizer: 14 | sgd: 15 | lr: 0.05 16 | learning_rate: 17 | constant: 0.05 18 | embeddings: 19 | tables: 20 | - name: user 21 | num_embeddings: 424_241 22 | embedding_dim: 4 23 | data_type: fp32 24 | optimizer: 25 | sgd: 26 | lr: 0.01 27 | learning_rate: 28 | constant: 0.01 29 | - name: tweet 30 | num_embeddings: 72_543 31 | embedding_dim: 4 32 | data_type: fp32 33 | optimizer: 34 | sgd: 35 | lr: 0.005 36 | learning_rate: 37 | constant: 0.005 38 | relations: 39 | - name: fav 40 | lhs: user 41 | rhs: tweet 42 | operator: translation 43 | - name: reply 44 | lhs: user 45 | rhs: tweet 46 | operator: translation 47 | - name: retweet 48 | lhs: user 49 | rhs: tweet 50 | operator: translation 51 | - name: magic_recs 52 | lhs: user 53 | rhs: tweet 54 | operator: translation 55 | train_data: 56 | data_root: "gs://follows_tml_01/tweet_eng/2023-01-23/large/edges/*" 57 | per_replica_batch_size: 500 58 | global_negatives: 0 59 | in_batch_negatives: 10 60 | limit: 9990 61 | validation_data: 62 | data_root: "gs://follows_tml_01/tweet_eng/2023-01-23/large/edges/*" 63 | per_replica_batch_size: 500 64 | global_negatives: 0 65 | in_batch_negatives: 10 66 | limit: 10 67 | offset: 9990 68 | -------------------------------------------------------------------------------- /LICENSE.torchrec: -------------------------------------------------------------------------------- 1 | A few files here (where it is specifically noted in comments) are based on code from 
torchrec but 2 | adapted for our use. Torchrec license is below: 3 | 4 | 5 | BSD 3-Clause License 6 | 7 | Copyright (c) Meta Platforms, Inc. and affiliates. 8 | All rights reserved. 9 | 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions are met: 12 | 13 | * Redistributions of source code must retain the above copyright notice, this 14 | list of conditions and the following disclaimer. 15 | 16 | * Redistributions in binary form must reproduce the above copyright notice, 17 | this list of conditions and the following disclaimer in the documentation 18 | and/or other materials provided with the distribution. 19 | 20 | * Neither the name of the copyright holder nor the names of its 21 | contributors may be used to endorse or promote products derived from 22 | this software without specific prior written permission. 23 | 24 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
34 | -------------------------------------------------------------------------------- /common/modules/embedding/embedding.py: -------------------------------------------------------------------------------- 1 | from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType 2 | from tml.ml_logging.torch_logging import logging 3 | 4 | import torch 5 | from torch import nn 6 | import torchrec 7 | from torchrec.modules import embedding_configs 8 | from torchrec import EmbeddingBagConfig, EmbeddingBagCollection 9 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor 10 | import numpy as np 11 | 12 | 13 | class LargeEmbeddings(nn.Module): 14 | def __init__( 15 | self, 16 | large_embeddings_config: LargeEmbeddingsConfig, 17 | ): 18 | super().__init__() 19 | 20 | tables = [] 21 | for table in large_embeddings_config.tables: 22 | data_type = ( 23 | embedding_configs.DataType.FP32 24 | if (table.data_type == DataType.FP32) 25 | else embedding_configs.DataType.FP16 26 | ) 27 | 28 | tables.append( 29 | EmbeddingBagConfig( 30 | embedding_dim=table.embedding_dim, 31 | feature_names=[table.name], # restricted to 1 feature per table for now 32 | name=table.name, 33 | num_embeddings=table.num_embeddings, 34 | pooling=torchrec.PoolingType.SUM, 35 | data_type=data_type, 36 | ) 37 | ) 38 | 39 | self.ebc = EmbeddingBagCollection( 40 | device="meta", 41 | tables=tables, 42 | ) 43 | 44 | logging.info("********************** EBC named params are **********") 45 | logging.info(list(self.ebc.named_parameters())) 46 | 47 | # This hook is used to perform post-processing surgery 48 | # on large_embedding models to prep them for serving 49 | self.surgery_cut_point = torch.nn.Identity() 50 | 51 | def forward( 52 | self, 53 | sparse_features: KeyedJaggedTensor, 54 | ) -> KeyedTensor: 55 | pooled_embs = self.ebc(sparse_features) 56 | 57 | # a KeyedTensor 58 | return self.surgery_cut_point(pooled_embs) 59 | 
-------------------------------------------------------------------------------- /projects/home/recap/model/mlp.py: -------------------------------------------------------------------------------- 1 | """MLP feed forward stack in torch.""" 2 | 3 | from tml.projects.home.recap.model.config import MlpConfig 4 | 5 | import torch 6 | from absl import logging 7 | 8 | 9 | def _init_weights(module): 10 | if isinstance(module, torch.nn.Linear): 11 | torch.nn.init.xavier_uniform_(module.weight) 12 | torch.nn.init.constant_(module.bias, 0) 13 | 14 | 15 | class Mlp(torch.nn.Module): 16 | def __init__(self, in_features: int, mlp_config: MlpConfig): 17 | super().__init__() 18 | self._mlp_config = mlp_config 19 | input_size = in_features 20 | layer_sizes = mlp_config.layer_sizes 21 | modules = [] 22 | for layer_size in layer_sizes[:-1]: 23 | modules.append(torch.nn.Linear(input_size, layer_size, bias=True)) 24 | 25 | if mlp_config.batch_norm: 26 | modules.append( 27 | torch.nn.BatchNorm1d( 28 | layer_size, affine=mlp_config.batch_norm.affine, momentum=mlp_config.batch_norm.momentum 29 | ) 30 | ) 31 | 32 | modules.append(torch.nn.ReLU()) 33 | 34 | if mlp_config.dropout: 35 | modules.append(torch.nn.Dropout(mlp_config.dropout.rate)) 36 | 37 | input_size = layer_size 38 | modules.append(torch.nn.Linear(input_size, layer_sizes[-1], bias=True)) 39 | if mlp_config.final_layer_activation: 40 | modules.append(torch.nn.ReLU()) 41 | self.layers = torch.nn.ModuleList(modules) 42 | self.layers.apply(_init_weights) 43 | 44 | def forward(self, x: torch.Tensor) -> torch.Tensor: 45 | net = x 46 | for i, layer in enumerate(self.layers): 47 | net = layer(net) 48 | if i == 1: # Share the first (widest?) set of activations for other applications. 
49 | shared_layer = net 50 | return {"output": net, "shared_layer": shared_layer} 51 | 52 | @property 53 | def shared_size(self): 54 | return self._mlp_config.layer_sizes[-1] 55 | 56 | @property 57 | def out_features(self): 58 | return self._mlp_config.layer_sizes[-1] 59 | -------------------------------------------------------------------------------- /projects/twhin/data/test_edges.py: -------------------------------------------------------------------------------- 1 | """Tests edges dataset functionality.""" 2 | 3 | from unittest.mock import patch 4 | import os 5 | import tempfile 6 | 7 | from tml.projects.twhin.data.edges import EdgesDataset 8 | from tml.projects.twhin.models.config import Relation 9 | 10 | from fsspec.implementations.local import LocalFileSystem 11 | import numpy as np 12 | import pyarrow as pa 13 | import pyarrow.compute as pc 14 | import pyarrow.parquet as pq 15 | import torch 16 | 17 | 18 | TABLE_SIZES = {"user": 16, "author": 32} 19 | RELATIONS = [ 20 | Relation(name="fav", lhs="user", rhs="author"), 21 | Relation(name="engaged_with_reply", lhs="author", rhs="user"), 22 | ] 23 | 24 | 25 | def test_gen(): 26 | import os 27 | import tempfile 28 | 29 | from fsspec.implementations.local import LocalFileSystem 30 | import pyarrow as pa 31 | import pyarrow.parquet as pq 32 | 33 | lhs = pa.array(np.arange(4)) 34 | rhs = pa.array(np.flip(np.arange(4))) 35 | rel = pa.array([0, 1, 0, 0]) 36 | names = ["lhs", "rhs", "rel"] 37 | 38 | with tempfile.TemporaryDirectory() as tmpdir: 39 | table = pa.Table.from_arrays([lhs, rhs, rel], names=names) 40 | writer = pq.ParquetWriter( 41 | os.path.join(tmpdir, "example.parquet"), 42 | table.schema, 43 | ) 44 | writer.write_table(table) 45 | writer.close() 46 | 47 | ds = EdgesDataset( 48 | file_pattern=os.path.join(tmpdir, "*"), 49 | table_sizes=TABLE_SIZES, 50 | relations=RELATIONS, 51 | batch_size=4, 52 | ) 53 | ds.FS = LocalFileSystem() 54 | 55 | dl = ds.dataloader() 56 | batch = next(iter(dl)) 57 | 58 | # 
labels should be positive 59 | labels = batch.labels 60 | assert (labels[:4] == 1).sum() == 4 61 | 62 | # make sure positive examples are what we expect 63 | kjt_values = batch.nodes.values() 64 | users, authors = torch.split(kjt_values, 4, dim=0) 65 | assert torch.equal(users[:4], torch.tensor([0, 2, 2, 3])) 66 | assert torch.equal(authors[:4], torch.tensor([3, 1, 1, 0])) 67 | -------------------------------------------------------------------------------- /projects/twhin/models/config.py: -------------------------------------------------------------------------------- 1 | import typing 2 | import enum 3 | 4 | from tml.common.modules.embedding.config import LargeEmbeddingsConfig 5 | from tml.core.config import base_config 6 | from tml.optimizers.config import OptimizerConfig 7 | 8 | import pydantic 9 | from pydantic import validator 10 | 11 | 12 | class TwhinEmbeddingsConfig(LargeEmbeddingsConfig): 13 | @validator("tables") 14 | def embedding_dims_match(cls, tables): 15 | embedding_dim = tables[0].embedding_dim 16 | data_type = tables[0].data_type 17 | for table in tables: 18 | assert table.embedding_dim == embedding_dim, "Embedding dimensions for all nodes must match." 19 | assert table.data_type == data_type, "Data types for all nodes must match." 20 | return tables 21 | 22 | 23 | class Operator(str, enum.Enum): 24 | TRANSLATION = "translation" 25 | 26 | 27 | class Relation(pydantic.BaseModel): 28 | """graph relationship properties and operator""" 29 | 30 | name: str = pydantic.Field(..., description="Relationship name.") 31 | lhs: str = pydantic.Field( 32 | ..., 33 | description="Name of the entity on the left-hand-side of this relation. Must match a table name.", 34 | ) 35 | rhs: str = pydantic.Field( 36 | ..., 37 | description="Name of the entity on the right-hand-side of this relation. 
Must match a table name.", 38 | ) 39 | operator: Operator = pydantic.Field( 40 | Operator.TRANSLATION, description="Transformation to apply to lhs embedding before dot product." 41 | ) 42 | 43 | 44 | class TwhinModelConfig(base_config.BaseConfig): 45 | embeddings: TwhinEmbeddingsConfig 46 | relations: typing.List[Relation] 47 | translation_optimizer: OptimizerConfig 48 | 49 | @validator("relations", each_item=True) 50 | def valid_node_types(cls, relation, values, **kwargs): 51 | table_names = [table.name for table in values["embeddings"].tables] 52 | assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}" 53 | assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}" 54 | return relation 55 | -------------------------------------------------------------------------------- /common/modules/embedding/config.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from enum import Enum 3 | 4 | import tml.core.config as base_config 5 | from tml.optimizers.config import OptimizerConfig 6 | 7 | import pydantic 8 | 9 | 10 | class DataType(str, Enum): 11 | FP32 = "fp32" 12 | FP16 = "fp16" 13 | 14 | 15 | class EmbeddingSnapshot(base_config.BaseConfig): 16 | """Configuration for Embedding snapshot""" 17 | 18 | emb_name: str = pydantic.Field( 19 | ..., description="Name of the embedding table from the loaded snapshot" 20 | ) 21 | embedding_snapshot_uri: str = pydantic.Field( 22 | ..., description="Path to torchsnapshot of the embedding" 23 | ) 24 | 25 | 26 | class EmbeddingBagConfig(base_config.BaseConfig): 27 | """Configuration for EmbeddingBag.""" 28 | 29 | name: str = pydantic.Field(..., description="name of embedding bag") 30 | num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary") 31 | embedding_dim: int = pydantic.Field(..., description="size of each embedding vector") 32 | pretrained: EmbeddingSnapshot = pydantic.Field(None, 
description="Snapshot properties") 33 | vocab: str = pydantic.Field( 34 | None, description="Directory to parquet files of mapping from entity ID to table index." 35 | ) 36 | # make sure to use an optimizer that matches: 37 | # https://github.com/pytorch/FBGEMM/blob/4c58137529d221390575e47e88d3c05ce65b66fd/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py#L15 38 | optimizer: OptimizerConfig 39 | data_type: DataType 40 | 41 | 42 | class LargeEmbeddingsConfig(base_config.BaseConfig): 43 | """Configuration for EmbeddingBagCollection. 44 | 45 | The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection. 46 | """ 47 | 48 | tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables") 49 | tables_to_log: List[str] = pydantic.Field( 50 | None, description="list of embedding table names that we want to log during training" 51 | ) 52 | 53 | 54 | class Mode(str, Enum): 55 | """Job modes.""" 56 | 57 | TRAIN = "train" 58 | EVALUATE = "evaluate" 59 | INFERENCE = "inference" 60 | -------------------------------------------------------------------------------- /projects/home/recap/config.py: -------------------------------------------------------------------------------- 1 | from tml.core import config as config_mod 2 | import tml.projects.home.recap.data.config as data_config 3 | import tml.projects.home.recap.model.config as model_config 4 | import tml.projects.home.recap.optimizer.config as optimizer_config 5 | 6 | from enum import Enum 7 | from typing import Dict, Optional 8 | import pydantic 9 | 10 | 11 | class TrainingConfig(config_mod.BaseConfig): 12 | save_dir: str = "/tmp/model" 13 | num_train_steps: pydantic.PositiveInt = 1000000 14 | initial_checkpoint_dir: str = pydantic.Field( 15 | None, description="Directory of initial checkpoints", at_most_one_of="initialization" 16 | ) 17 | checkpoint_every_n: pydantic.PositiveInt = 1000 18 | checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field( 19 | None, 
description="Maximum number of checkpoints to keep. Defaults to keeping all." 20 | ) 21 | train_log_every_n: pydantic.PositiveInt = 1000 22 | num_eval_steps: int = pydantic.Field( 23 | 16384, description="Number of evaluation steps. If < 0 the entire dataset " "will be used." 24 | ) 25 | eval_log_every_n: pydantic.PositiveInt = 5000 26 | 27 | eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60 28 | 29 | gradient_accumulation: int = pydantic.Field( 30 | None, description="Number of replica steps to accumulate gradients." 31 | ) 32 | 33 | 34 | class RecapConfig(config_mod.BaseConfig): 35 | training: TrainingConfig = pydantic.Field(TrainingConfig()) 36 | model: model_config.ModelConfig 37 | train_data: data_config.RecapDataConfig 38 | validation_data: Dict[str, data_config.RecapDataConfig] 39 | optimizer: optimizer_config.RecapOptimizerConfig 40 | 41 | which_metrics: Optional[str] = pydantic.Field(None, description="which metrics to pick.") 42 | 43 | # DANGER DANGER! You might expect validators here to ensure that multi task learning setups are 44 | # the same as the data. Unfortunately, this throws opaque errors when the model configuration is 45 | # invalid. In our judgement, that is a more frequent and worse occurrence than tasks not matching 46 | # the data.
47 | 48 | 49 | class JobMode(str, Enum): 50 | """Job modes.""" 51 | 52 | TRAIN = "train" 53 | EVALUATE = "evaluate" 54 | INFERENCE = "inference" 55 | -------------------------------------------------------------------------------- /reader/test_dataset.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import os 3 | from unittest.mock import patch 4 | 5 | import tml.reader.utils as reader_utils 6 | from tml.reader.dataset import Dataset 7 | 8 | import pyarrow as pa 9 | import pyarrow.parquet as pq 10 | import pytest 11 | import torch 12 | 13 | 14 | def create_dataset(tmpdir): 15 | 16 | table = pa.table( 17 | { 18 | "year": [2020, 2022, 2021, 2022, 2019, 2021], 19 | "n_legs": [2, 2, 4, 4, 5, 100], 20 | } 21 | ) 22 | file_path = tmpdir 23 | pq.write_to_dataset(table, root_path=str(file_path)) 24 | 25 | class MockDataset(Dataset): 26 | def __init__(self, *args, **kwargs): 27 | super().__init__(*args, **kwargs) 28 | self._pa_to_batch = reader_utils.create_default_pa_to_batch(self._schema) 29 | 30 | def pa_to_batch(self, batch): 31 | return self._pa_to_batch(batch) 32 | 33 | return MockDataset(file_pattern=str(file_path / "*"), batch_size=2) 34 | 35 | 36 | def test_dataset(tmpdir): 37 | ds = create_dataset(tmpdir) 38 | batch = next(iter(ds.dataloader(remote=False))) 39 | assert batch.batch_size == 2 40 | assert torch.equal(batch.year, torch.Tensor([2020, 2022])) 41 | assert torch.equal(batch.n_legs, torch.Tensor([2, 2])) 42 | 43 | 44 | @pytest.mark.skipif( 45 | os.environ.get("GITHUB_WORKSPACE") is not None, 46 | reason="Multiprocessing doesn't work on github yet.", 47 | ) 48 | def test_distributed_dataset(tmpdir): 49 | MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"} 50 | 51 | def _client(): 52 | with patch.dict(os.environ, MOCK_ENV): 53 | with patch( 54 | "tml.reader.dataset.env.get_flight_server_addresses", return_value=["grpc://localhost:2222"] 55 | ): 56 | ds = create_dataset(tmpdir) 57 | batch = 
next(iter(ds.dataloader(remote=True))) 58 | assert batch.batch_size == 2 59 | assert torch.equal(batch.year, torch.Tensor([2020, 2022])) 60 | assert torch.equal(batch.n_legs, torch.Tensor([2, 2])) 61 | 62 | def _worker(): 63 | ds = create_dataset(tmpdir) 64 | ds.serve() 65 | 66 | worker = mp.Process(target=_worker) 67 | client = mp.Process(target=_client) 68 | worker.start() 69 | client.start() 70 | client.join() 71 | assert not client.exitcode 72 | worker.kill() 73 | client.kill() 74 | -------------------------------------------------------------------------------- /ml_logging/torch_logging.py: -------------------------------------------------------------------------------- 1 | """Overrides absl logger to be rank-aware for distributed pytorch usage. 2 | 3 | >>> # in-bazel import 4 | >>> from twitter.ml.logging.torch_logging import logging 5 | >>> # out-bazel import 6 | >>> from ml.logging.torch_logging import logging 7 | >>> logging.info(f"This only prints on rank 0 if distributed, otherwise prints normally.") 8 | >>> logging.info(f"This prints on all ranks if distributed, otherwise prints normally.", rank=-1) 9 | 10 | """ 11 | import functools 12 | from typing import Optional 13 | 14 | from tml.ml_logging.absl_logging import logging as logging 15 | from absl import logging as absl_logging 16 | 17 | import torch.distributed as dist 18 | 19 | 20 | def rank_specific(logger): 21 | """Ensures that we only override a given logger once.""" 22 | if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"): 23 | return logger 24 | 25 | def _if_rank(logger_method, limit: Optional[int] = None): 26 | if limit: 27 | # If we are limiting redundant logs, wrap logging call with a cache 28 | # to not execute if already cached. 
29 | def _wrap(_call): 30 | @functools.lru_cache(limit) 31 | def _logger_method(*args, **kwargs): 32 | _call(*args, **kwargs) 33 | 34 | return _logger_method 35 | 36 | logger_method = _wrap(logger_method) 37 | 38 | def _inner(msg, *args, rank: int = 0, **kwargs): 39 | if not dist.is_initialized(): 40 | logger_method(msg, *args, **kwargs) 41 | elif dist.get_rank() == rank: 42 | logger_method(msg, *args, **kwargs) 43 | elif rank < 0: 44 | logger_method(f"Rank{dist.get_rank()}: {msg}", *args, **kwargs) 45 | 46 | # Register this stack frame with absl logging so that it doesn't trample logging lines. 47 | absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__) 48 | 49 | return _inner 50 | 51 | logger.fatal = _if_rank(logger.fatal) 52 | logger.error = _if_rank(logger.error) 53 | logger.warning = _if_rank(logger.warning, limit=1) 54 | logger.info = _if_rank(logger.info) 55 | logger.debug = _if_rank(logger.debug) 56 | logger.exception = _if_rank(logger.exception) 57 | 58 | logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True 59 | 60 | 61 | rank_specific(logging) 62 | -------------------------------------------------------------------------------- /projects/twhin/optimizer.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from tml.projects.twhin.models.config import TwhinModelConfig 4 | from tml.projects.twhin.models.models import TwhinModel 5 | from tml.optimizers.optimizer import get_optimizer_class, LRShim 6 | from tml.optimizers.config import get_optimizer_algorithm_config, LearningRate 7 | from tml.ml_logging.torch_logging import logging 8 | 9 | from torchrec.optim.optimizers import in_backward_optimizer_filter 10 | from torchrec.optim import keyed 11 | 12 | 13 | FUSED_OPT_KEY = "fused_opt" 14 | TRANSLATION_OPT_KEY = "operator_opt" 15 | 16 | 17 | def _lr_from_config(optimizer_config): 18 | if optimizer_config.learning_rate is not None: 19 | return optimizer_config.learning_rate 20 | 
else: 21 | # treat None as constant lr 22 | lr_value = get_optimizer_algorithm_config(optimizer_config).lr 23 | return LearningRate(constant=lr_value) 24 | 25 | 26 | def build_optimizer(model: TwhinModel, config: TwhinModelConfig): 27 | """Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations. 28 | 29 | Args: 30 | model: TwhinModel to build optimizer for. 31 | config: TwhinConfig for model. 32 | 33 | Returns: 34 | Optimizer for model. 35 | """ 36 | translation_optimizer_fn = functools.partial( 37 | get_optimizer_class(config.translation_optimizer), 38 | **get_optimizer_algorithm_config(config.translation_optimizer).dict(), 39 | ) 40 | 41 | translation_optimizer = keyed.KeyedOptimizerWrapper( 42 | dict(in_backward_optimizer_filter(model.named_parameters())), 43 | optim_factory=translation_optimizer_fn, 44 | ) 45 | 46 | lr_dict = {} 47 | for table in config.embeddings.tables: 48 | lr_dict[table.name] = _lr_from_config(table.optimizer) 49 | lr_dict[TRANSLATION_OPT_KEY] = _lr_from_config(config.translation_optimizer) 50 | 51 | logging.info(f"***** LR dict: {lr_dict} *****") 52 | 53 | logging.info( 54 | f"***** Combining fused optimizer {model.fused_optimizer} with operator optimizer: {translation_optimizer} *****" 55 | ) 56 | optimizer = keyed.CombinedOptimizer( 57 | [ 58 | (FUSED_OPT_KEY, model.fused_optimizer), 59 | (TRANSLATION_OPT_KEY, translation_optimizer), 60 | ] 61 | ) 62 | 63 | # scheduler = LRShim(optimizer, lr_dict) 64 | scheduler = None 65 | 66 | logging.info(f"***** Combined optimizer after init: {optimizer} *****") 67 | 68 | return optimizer, scheduler 69 | -------------------------------------------------------------------------------- /core/config/base_config.py: -------------------------------------------------------------------------------- 1 | """Base class for all config (forbids extra fields).""" 2 | 3 | import collections 4 | import functools 5 | import yaml 6 | 7 | import pydantic 
8 | 9 | 10 | class BaseConfig(pydantic.BaseModel): 11 | """Base class for all derived config classes. 12 | 13 | This class provides some convenient functionality: 14 | - Disallows extra fields when constructing an object. User error 15 | should be reduced by exact arguments. 16 | - "one_of" fields. A subclass can group optional fields and enforce 17 | that only one of the fields be set. For example: 18 | 19 | ``` 20 | class ExampleConfig(BaseConfig): 21 | x: int = Field(None, one_of="group_1") 22 | y: int = Field(None, one_of="group_1") 23 | 24 | ExampleConfig(x=1) # ok 25 | ExampleConfig(y=1) # ok 26 | ExampleConfig(x=1, y=1) # throws error 27 | ``` 28 | """ 29 | 30 | class Config: 31 | """Forbids extras.""" 32 | 33 | extra = pydantic.Extra.forbid # noqa 34 | 35 | @classmethod 36 | @functools.lru_cache() # Cached per (cls, field_data_name), so each class's schema is scanned once per key. 37 | def _field_data_map(cls, field_data_name): 38 | """Map each value of the given extra field attribute (e.g. one_of) to the names of the fields carrying it.""" 39 | schema = cls.schema() 40 | one_of = collections.defaultdict(list) 41 | for field, fdata in schema["properties"].items(): 42 | if field_data_name in fdata: 43 | one_of[fdata[field_data_name]].append(field) 44 | return one_of 45 | 46 | @pydantic.root_validator 47 | def _one_of_check(cls, values): 48 | """Validate that exactly one field in each 'one_of' group is set (not None).""" 49 | one_of_map = cls._field_data_map("one_of") 50 | for one_of, field_names in one_of_map.items(): 51 | if sum([values.get(n, None) is not None for n in field_names]) != 1: 52 | raise ValueError(f"Exactly one of {','.join(field_names)} required.") 53 | return values 54 | 55 | @pydantic.root_validator 56 | def _at_most_one_of_check(cls, values): 57 | """Validate that at most one field in each 'at_most_one_of' group is set (not None).""" 58 | at_most_one_of_map = cls._field_data_map("at_most_one_of") 59 | for one_of, field_names in at_most_one_of_map.items(): 60 | if sum([values.get(n, None) is not None for n in field_names]) > 1: 61 | raise ValueError(f"At most one of 
{','.join(field_names)} can be set.") 62
| return values 63 | 64 | def pretty_print(self) -> str: 65 | """Return a human legible (yaml) representation of the config useful for logging.""" 66 | return yaml.dump(self.dict()) 67 | -------------------------------------------------------------------------------- /projects/home/recap/model/model_and_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, List 2 | from tml.projects.home.recap.embedding import config as embedding_config_mod 3 | import torch 4 | from absl import logging 5 | 6 | 7 | class ModelAndLoss(torch.nn.Module): 8 | def __init__( 9 | self, 10 | model, 11 | loss_fn: Callable, 12 | stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None, 13 | ) -> None: 14 | """ 15 | Args: 16 | model: torch module to wrap. 17 | loss_fn: Function for calculating loss; called as loss_fn(logits, labels, weights). 18 | stratifiers: optional list of stratifier configs; for each, discrete_features[:, index] is emitted under outputs["stratifiers"][name] for metrics stratification. 19 | """ 20 | super().__init__() 21 | self.model = model 22 | self.loss_fn = loss_fn 23 | self.stratifiers = stratifiers 24 | 25 | def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] 26 | """Runs model forward and calculates loss according to given loss_fn. 27 | 28 | NOTE: The input signature here needs to be a Pipelineable object for 29 | prefetching purposes during training using torchrec's pipeline. However 30 | the underlying model signature needs to be exportable to onnx, requiring 31 | generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
32 | 33 | """ 34 | outputs = self.model( 35 | continuous_features=batch.continuous_features, 36 | binary_features=batch.binary_features, 37 | discrete_features=batch.discrete_features, 38 | sparse_features=batch.sparse_features, 39 | user_embedding=batch.user_embedding, 40 | user_eng_embedding=batch.user_eng_embedding, 41 | author_embedding=batch.author_embedding, 42 | labels=batch.labels, 43 | weights=batch.weights, 44 | ) 45 | losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float()) 46 | 47 | if self.stratifiers: 48 | logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}") # NOTE(review): this logs on every forward call. 49 | outputs["stratifiers"] = {} 50 | for stratifier in self.stratifiers: 51 | outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index] 52 | 53 | # In general, we can have a large number of losses returned by our loss function. 54 | if isinstance(losses, dict): 55 | return losses["loss"], { 56 | **outputs, 57 | **losses, 58 | "labels": batch.labels, 59 | "weights": batch.weights, 60 | } 61 | else: # Assume a single loss value. 62 | return losses, { 63 | **outputs, 64 | "loss": losses, 65 | "labels": batch.labels, 66 | "weights": batch.weights, 67 | } 68 | -------------------------------------------------------------------------------- /optimizers/config.py: -------------------------------------------------------------------------------- 1 | """Optimization configurations for models.""" 2 | 3 | import typing 4 | 5 | import tml.core.config as base_config 6 | 7 | import pydantic 8 | 9 | 10 | class PiecewiseConstant(base_config.BaseConfig): 11 | learning_rate_boundaries: typing.List[int] = pydantic.Field(None) 12 | learning_rate_values: typing.List[float] = pydantic.Field(None) 13 | 14 | 15 | class LinearRampToConstant(base_config.BaseConfig): 16 | learning_rate: float 17 | num_ramp_steps: pydantic.PositiveInt = pydantic.Field( 18 | description="Number of steps to ramp this up from zero."
19 | ) 20 | 21 | 22 | class LinearRampToCosine(base_config.BaseConfig): 23 | learning_rate: float 24 | final_learning_rate: float 25 | num_ramp_steps: pydantic.PositiveInt = pydantic.Field( 26 | description="Number of steps to ramp this up from zero." 27 | ) 28 | final_num_steps: pydantic.PositiveInt = pydantic.Field( 29 | description="Final number of steps where decay stops." 30 | ) 31 | 32 | 33 | class LearningRate(base_config.BaseConfig): 34 | constant: float = pydantic.Field(None, one_of="lr") 35 | linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr") 36 | linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr") 37 | piecewise_constant: PiecewiseConstant = pydantic.Field(None, one_of="lr") 38 | 39 | 40 | class OptimizerAlgorithmConfig(base_config.BaseConfig): 41 | """Base class for optimizer configurations.""" 42 | 43 | lr: float 44 | ... 45 | 46 | 47 | class AdamConfig(OptimizerAlgorithmConfig): 48 | # see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam 49 | lr: float 50 | betas: typing.Tuple[float, float] = [0.9, 0.999] # NOTE(review): default is a list although the field is annotated as a Tuple. 51 | eps: float = 1e-7 # Numerical stability in denominator.
52 | 53 | 54 | class SgdConfig(OptimizerAlgorithmConfig): 55 | lr: float 56 | momentum: float = 0.0 57 | 58 | 59 | class AdagradConfig(OptimizerAlgorithmConfig): 60 | lr: float 61 | eps: float = 0 62 | 63 | 64 | class OptimizerConfig(base_config.BaseConfig): 65 | learning_rate: LearningRate = pydantic.Field( 66 | None, 67 | description="Constant learning rates", 68 | ) 69 | adam: AdamConfig = pydantic.Field(None, one_of="optimizer") 70 | sgd: SgdConfig = pydantic.Field(None, one_of="optimizer") 71 | adagrad: AdagradConfig = pydantic.Field(None, one_of="optimizer") 72 | 73 | 74 | def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig): # Return whichever algorithm config is set, checking adam, sgd, adagrad in that order. 75 | if optimizer_config.adam is not None: 76 | return optimizer_config.adam 77 | elif optimizer_config.sgd is not None: 78 | return optimizer_config.sgd 79 | elif optimizer_config.adagrad is not None: 80 | return optimizer_config.adagrad 81 | else: 82 | raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}") 83 | -------------------------------------------------------------------------------- /common/batch.py: -------------------------------------------------------------------------------- 1 | """Extension of torchrec.dataset.utils.Batch to cover any dataset.
2 | """ 3 | # flake8: noqa 4 | from __future__ import annotations 5 | from typing import Dict 6 | import abc 7 | from dataclasses import dataclass 8 | import dataclasses 9 | 10 | import torch 11 | from torchrec.streamable import Pipelineable 12 | 13 | 14 | class BatchBase(Pipelineable, abc.ABC): 15 | @abc.abstractmethod 16 | def as_dict(self) -> Dict: 17 | raise NotImplementedError 18 | 19 | def to(self, device: torch.device, non_blocking: bool = False): # NOTE(review): calls .to() on every value, so a None field would raise AttributeError. 20 | args = {} 21 | for feature_name, feature_value in self.as_dict().items(): 22 | args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking) 23 | return self.__class__(**args) 24 | 25 | def record_stream(self, stream: torch.cuda.streams.Stream) -> None: 26 | for feature_value in self.as_dict().values(): 27 | feature_value.record_stream(stream) 28 | 29 | def pin_memory(self): # NOTE(review): like to(), None values are not skipped here. 30 | args = {} 31 | for feature_name, feature_value in self.as_dict().items(): 32 | args[feature_name] = feature_value.pin_memory() 33 | return self.__class__(**args) 34 | 35 | def __repr__(self) -> str: 36 | def obj2str(v): 37 | return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}" 38 | 39 | return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()]) 40 | 41 | @property 42 | def batch_size(self) -> int: 43 | for tensor in self.as_dict().values(): 44 | if tensor is None: 45 | continue 46 | if not isinstance(tensor, torch.Tensor): 47 | continue 48 | return tensor.shape[0] 49 | raise Exception("Could not determine batch size from tensors.") 50 | 51 | 52 | @dataclass 53 | class DataclassBatch(BatchBase): 54 | @classmethod 55 | def feature_names(cls): 56 | return list(cls.__dataclass_fields__.keys()) 57 | 58 | def as_dict(self): 59 | return { 60 | feature_name: getattr(self, feature_name) 61 | for feature_name in self.feature_names() 62 | if hasattr(self, feature_name) 63 | } 64 | 65 | @staticmethod 66 | def from_schema(name: 
str, schema): 67 | """Instantiates a custom batch subclass if all
columns can be represented as a torch.Tensor.""" 68 | return dataclasses.make_dataclass( 69 | cls_name=name, 70 | fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names], # `name` here is the comprehension's column name; cls_name above already used the parameter. 71 | bases=(DataclassBatch,), 72 | ) 73 | 74 | @staticmethod 75 | def from_fields(name: str, fields: dict): # Build a DataclassBatch subclass with one None-defaulted field per (name, type) entry. 76 | return dataclasses.make_dataclass( 77 | cls_name=name, 78 | fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()], 79 | bases=(DataclassBatch,), 80 | ) 81 | 82 | 83 | class DictionaryBatch(BatchBase, dict): 84 | def as_dict(self) -> Dict: 85 | return self 86 | -------------------------------------------------------------------------------- /reader/utils.py: -------------------------------------------------------------------------------- 1 | """Reader utilities.""" 2 | import itertools 3 | import time 4 | from typing import Optional 5 | 6 | from tml.common.batch import DataclassBatch 7 | from tml.ml_logging.torch_logging import logging 8 | 9 | import pyarrow as pa 10 | import torch 11 | 12 | 13 | def roundrobin(*iterables): 14 | """Round robin through provided iterables, useful for simple load balancing. 15 | 16 | Adapted from https://docs.python.org/3/library/itertools.html. 17 | 18 | """ 19 | num_active = len(iterables) 20 | nexts = itertools.cycle(iter(it).__next__ for it in iterables) 21 | while num_active: 22 | try: 23 | for _next in nexts: 24 | result = _next() 25 | yield result 26 | except StopIteration: 27 | # Remove the iterator we just exhausted from the cycle.
28 | num_active -= 1 29 | nexts = itertools.cycle(itertools.islice(nexts, num_active)) 30 | logging.warning(f"Iterable exhausted, {num_active} iterables left.") 31 | except Exception as exc: 32 | logging.warning(f"Iterable raised exception {exc}, ignoring.") 33 | # NOTE(review): despite the "ignoring" log text, the exception is re-raised. 34 | raise 35 | 36 | 37 | def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]): 38 | num_examples = 0 39 | prev = time.perf_counter() 40 | for idx, batch in enumerate(data_loader): 41 | if idx > max_steps: # NOTE(review): this processes max_steps + 1 batches (idx 0..max_steps). 42 | break 43 | if peek and idx % peek == 0: 44 | logging.info(f"Batch: {batch}") 45 | num_examples += batch.batch_size 46 | if idx % frequency == 0: 47 | now = time.perf_counter() 48 | elapsed = now - prev 49 | logging.info( 50 | f"step: {idx}, " 51 | f"elapsed(s): {elapsed}, " 52 | f"examples: {num_examples}, " 53 | f"ex/s: {num_examples / elapsed}, " 54 | ) 55 | prev = now 56 | num_examples = 0 57 | 58 | 59 | def pa_to_torch(array: pa.array) -> torch.Tensor: 60 | return torch.from_numpy(array.to_numpy()) 61 | 62 | 63 | def create_default_pa_to_batch(schema) -> DataclassBatch: 64 | """Return a function converting a pyarrow RecordBatch into a DefaultBatch of tensors, with nulls filled by zero/empty-string defaults (despite the annotation, the return value is that callable).""" 65 | _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema) 66 | 67 | def get_imputation_value(pa_type): 68 | type_map = { 69 | pa.float64(): pa.scalar(0, type=pa.float64()), 70 | pa.int64(): pa.scalar(0, type=pa.int64()), 71 | pa.string(): pa.scalar("", type=pa.string()), 72 | } 73 | if pa_type not in type_map: 74 | raise Exception(f"Imputation for type {pa_type} not supported.") 75 | return type_map[pa_type] 76 | 77 | def _impute(array: pa.array) -> pa.array: 78 | return array.fill_null(get_imputation_value(array.type)) 79 | 80 | def _column_to_tensor(record_batch: pa.RecordBatch): 81 | tensors = { 82 | col_name: pa_to_torch(_impute(record_batch.column(col_name))) 83 | 
for col_name in record_batch.schema.names 84 | } 85 | return _CustomBatch(**tensors) 86 | 87 | return _column_to_tensor 88 |
-------------------------------------------------------------------------------- /projects/home/recap/data/generate_random_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from absl import app, flags, logging 4 | import tensorflow as tf 5 | from typing import Dict 6 | 7 | from tml.projects.home.recap.data import tfe_parsing 8 | from tml.core import config as tml_config_mod 9 | import tml.projects.home.recap.config as recap_config_mod 10 | 11 | flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.") 12 | flags.DEFINE_integer("n_examples", 100, "Numer of examples to generate.") 13 | 14 | FLAGS = flags.FLAGS 15 | 16 | 17 | def _generate_random_example( 18 | tf_example_schema: Dict[str, tf.io.FixedLenFeature] 19 | ) -> Dict[str, tf.Tensor]: 20 | example = {} 21 | for feature_name, feature_spec in tf_example_schema.items(): 22 | dtype = feature_spec.dtype 23 | if (dtype == tf.int64) or (dtype == tf.int32): 24 | x = tf.experimental.numpy.random.randint(0, high=10, size=feature_spec.shape, dtype=dtype) 25 | elif (dtype == tf.float32) or (dtype == tf.float64): 26 | x = tf.random.uniform(shape=[feature_spec.shape], dtype=dtype) 27 | else: 28 | raise NotImplementedError(f"Unknown type {dtype}") 29 | 30 | example[feature_name] = x 31 | 32 | return example 33 | 34 | 35 | def _float_feature(value): 36 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 37 | 38 | 39 | def _int64_feature(value): 40 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 41 | 42 | 43 | def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes: 44 | feature = {} 45 | serializers = {tf.float32: _float_feature, tf.int64: _int64_feature} 46 | for feature_name, tensor in x.items(): 47 | feature[feature_name] = serializers[tensor.dtype](tensor) 48 | 49 | example_proto = tf.train.Example(features=tf.train.Features(feature=feature)) 50 | return example_proto.SerializeToString() 
51 | 52 | 53 | def generate_data(data_path: str, config: recap_config_mod.RecapConfig): 54 | with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f: 55 | seg_dense_schema = json.load(f)["schema"] 56 | 57 | tf_example_schema = tfe_parsing.create_tf_example_schema( 58 | config.train_data, 59 | seg_dense_schema, 60 | ) 61 | 62 | record_filename = os.path.join(data_path, "random.tfrecord.gz") 63 | 64 | with tf.io.TFRecordWriter(record_filename, "GZIP") as writer: 65 | random_example = _generate_random_example(tf_example_schema) 66 | serialized_example = _serialize_example(random_example) 67 | writer.write(serialized_example) 68 | 69 | 70 | def _generate_data_main(unused_argv): 71 | config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path) 72 | 73 | # Find the path where to put the data 74 | data_path = os.path.dirname(config.train_data.inputs) 75 | logging.info("Putting random data in %s", data_path) 76 | 77 | generate_data(data_path, config) 78 | 79 | 80 | if __name__ == "__main__": 81 | app.run(_generate_data_main) 82 | -------------------------------------------------------------------------------- /core/test_train_pipeline.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple 3 | 4 | from tml.common.batch import DataclassBatch 5 | from tml.common.testing_utils import mock_pg 6 | from tml.core import train_pipeline 7 | 8 | import torch 9 | from torchrec.distributed import DistributedModelParallel 10 | 11 | 12 | @dataclass 13 | class MockDataclassBatch(DataclassBatch): 14 | continuous_features: torch.Tensor 15 | labels: torch.Tensor 16 | 17 | 18 | class MockModule(torch.nn.Module): 19 | def __init__(self) -> None: 20 | super().__init__() 21 | self.model = torch.nn.Linear(10, 1) 22 | self.loss_fn = torch.nn.BCEWithLogitsLoss() 23 | 24 | def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, 
torch.Tensor]: 25 | pred = self.model(batch.continuous_features) 26 | loss = self.loss_fn(pred, batch.labels) 27 | return (loss, pred) 28 | 29 | 30 | def create_batch(bsz: int): 31 | return MockDataclassBatch( 32 | continuous_features=torch.rand(bsz, 10).float(), 33 | labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(), 34 | ) 35 | 36 | 37 | def test_sparse_pipeline(): 38 | device = torch.device("cpu") 39 | model = MockModule().to(device) 40 | 41 | steps = 8 42 | example = create_batch(1) 43 | dataloader = iter(example for _ in range(steps + 2)) 44 | 45 | results = [] 46 | with mock_pg(): 47 | d_model = DistributedModelParallel(model) 48 | pipeline = train_pipeline.TrainPipelineSparseDist( 49 | model=d_model, 50 | optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9), 51 | device=device, 52 | grad_accum=2, 53 | ) 54 | for _ in range(steps): 55 | results.append(pipeline.progress(dataloader)) 56 | 57 | results = [elem.detach().numpy() for elem in results] 58 | # Check gradients are accumulated, i.e. results do not change for every 0th and 1th. 59 | for first, second in zip(results[::2], results[1::2]): 60 | assert first == second, results 61 | 62 | # Check we do update gradients, i.e. results do change for every 1th and 2nd. 63 | for first, second in zip(results[1::2], results[2::2]): 64 | assert first != second, results 65 | 66 | 67 | def test_amp(): 68 | device = torch.device("cpu") 69 | model = MockModule().to(device) 70 | 71 | steps = 8 72 | example = create_batch(1) 73 | dataloader = iter(example for _ in range(steps + 2)) 74 | 75 | results = [] 76 | with mock_pg(): 77 | d_model = DistributedModelParallel(model) 78 | pipeline = train_pipeline.TrainPipelineSparseDist( 79 | model=d_model, 80 | optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9), 81 | device=device, 82 | enable_amp=True, 83 | # Not supported on CPU. 
84 | enable_grad_scaling=False, 85 | ) 86 | for _ in range(steps): 87 | results.append(pipeline.progress(dataloader)) 88 | 89 | results = [elem.detach() for elem in results] 90 | for value in results: 91 | assert value.dtype == torch.bfloat16 92 | -------------------------------------------------------------------------------- /machines/environment.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import List 4 | 5 | 6 | KF_DDS_PORT: int = 5050 7 | SLURM_DDS_PORT: int = 5051 8 | FLIGHT_SERVER_PORT: int = 2222 9 | 10 | 11 | def on_kf(): 12 | return "SPEC_TYPE" in os.environ 13 | 14 | 15 | def has_readers(): 16 | if on_kf(): 17 | machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) 18 | return machines_config_env["dataset_worker"] is not None 19 | return os.environ.get("HAS_READERS", "False") == "True" 20 | 21 | 22 | def get_task_type(): 23 | if on_kf(): 24 | return os.environ["SPEC_TYPE"] 25 | return os.environ["TASK_TYPE"] 26 | 27 | 28 | def is_chief() -> bool: 29 | return get_task_type() == "chief" 30 | 31 | 32 | def is_reader() -> bool: 33 | return get_task_type() == "datasetworker" 34 | 35 | 36 | def is_dispatcher() -> bool: 37 | return get_task_type() == "datasetdispatcher" 38 | 39 | 40 | def get_task_index(): 41 | if on_kf(): 42 | pod_name = os.environ["MY_POD_NAME"] 43 | return int(pod_name.split("-")[-1]) 44 | else: 45 | raise NotImplementedError 46 | 47 | 48 | def get_reader_port(): 49 | if on_kf(): 50 | return KF_DDS_PORT 51 | return SLURM_DDS_PORT 52 | 53 | 54 | def get_dds(): 55 | if not has_readers(): 56 | return None 57 | dispatcher_address = get_dds_dispatcher_address() 58 | if dispatcher_address: 59 | return f"grpc://{dispatcher_address}" 60 | else: 61 | raise ValueError("Job does not have DDS.") 62 | 63 | 64 | def get_dds_dispatcher_address(): 65 | if not has_readers(): 66 | return None 67 | if on_kf(): 68 | job_name = os.environ["JOB_NAME"] 69 | dds_host = 
f"{job_name}-datasetdispatcher-0" 70 | else: 71 | dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"] 72 | return f"{dds_host}:{get_reader_port()}" 73 | 74 | 75 | def get_dds_worker_address(): 76 | if not has_readers(): 77 | return None 78 | if on_kf(): 79 | job_name = os.environ["JOB_NAME"] 80 | task_index = get_task_index() 81 | return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}" 82 | else: 83 | node = os.environ["SLURMD_NODENAME"] 84 | return f"{node}:{get_reader_port()}" 85 | 86 | 87 | def get_num_readers(): 88 | if not has_readers(): 89 | return 0 90 | if on_kf(): 91 | machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) 92 | return int(machines_config_env["num_dataset_workers"] or 0) 93 | return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(",")) 94 | 95 | 96 | def get_flight_server_addresses(): 97 | if on_kf(): 98 | job_name = os.environ["JOB_NAME"] 99 | return [ 100 | f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}" 101 | for task_index in range(get_num_readers()) 102 | ] 103 | else: 104 | raise NotImplementedError 105 | 106 | 107 | def get_dds_journaling_dir(): 108 | return os.environ.get("DATASET_JOURNALING_DIR", None) 109 | -------------------------------------------------------------------------------- /tools/pq.py: -------------------------------------------------------------------------------- 1 | """Local reader of parquet files. 2 | 3 | 1. Make sure you are initialized locally: 4 | ``` 5 | ./images/init_venv_macos.sh 6 | ``` 7 | 2. Activate 8 | ``` 9 | source ~/tml_venv/bin/activate 10 | ``` 11 | 3. Use tool, e.g. 12 | 13 | `head` prints the first `--num` rows of the dataset. 14 | ``` 15 | python3 tools/pq.py \ 16 | --num 5 --path "tweet_eng/small/edges/all/*" \ 17 | head 18 | ``` 19 | 20 | `distinct` prints the observed values in the first `--num` rows for the specified columns. 
21 | ``` 22 | python3 tools/pq.py \ 23 | --num 1000000000 --columns '["rel"]' \ 24 | --path "tweet_eng/small/edges/all/*" \ 25 | distinct 26 | ``` 27 | 28 | """ 29 | from typing import List, Optional 30 | 31 | from tml.common.filesystem import infer_fs 32 | 33 | import fire 34 | import pandas as pd 35 | import pyarrow as pa 36 | import pyarrow.dataset as pads 37 | import pyarrow.parquet as pq 38 | 39 | 40 | def _create_dataset(path: str): 41 | fs = infer_fs(path) 42 | files = fs.glob(path) 43 | return pads.dataset(files, format="parquet", filesystem=fs) 44 | 45 | 46 | class PqReader: 47 | def __init__( 48 | self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None 49 | ): 50 | self._ds = _create_dataset(path) 51 | self._batch_size = batch_size 52 | self._num = num 53 | self._columns = columns 54 | 55 | def __iter__(self): 56 | batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns) 57 | rows_seen = 0 58 | for count, record in enumerate(batches): 59 | if self._num and rows_seen >= self._num: 60 | break 61 | yield record 62 | rows_seen += record.data.num_rows 63 | 64 | def _head(self): 65 | total_read = self._num * self.bytes_per_row 66 | if total_read >= int(500e6): 67 | raise Exception( 68 | "Sorry you're trying to read more than 500 MB " f"into memory ({total_read} bytes)." 69 | ) 70 | return self._ds.head(self._num, columns=self._columns) 71 | 72 | @property 73 | def bytes_per_row(self) -> int: 74 | nbits = 0 75 | for t in self._ds.schema.types: 76 | try: 77 | nbits += t.bit_width 78 | except: 79 | # Just estimate size if it is variable 80 | nbits += 8 81 | return nbits // 8 82 | 83 | def schema(self): 84 | print(f"\n# Schema\n{self._ds.schema}") 85 | 86 | def head(self): 87 | """Displays first --num rows.""" 88 | print(self._head().to_pandas()) 89 | 90 | def distinct(self): 91 | """Displays unique values seen in specified columns in the first `--num` rows. 
92 | 93 | Useful for getting an approximate vocabulary for certain columns. 94 | 95 | """ 96 | for col_name, column in zip(self._head().column_names, self._head().columns): 97 | print(col_name) 98 | print("unique:", column.unique().to_pylist()) 99 | 100 | 101 | if __name__ == "__main__": 102 | pd.set_option("display.max_columns", None) 103 | pd.set_option("display.max_rows", None) 104 | fire.Fire(PqReader) 105 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """Wraps servable model in loss and RecapBatch passing to be trainable.""" 2 | # flake8: noqa 3 | from typing import Callable 4 | 5 | from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined] 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from torchrec.distributed.model_parallel import DistributedModelParallel 10 | 11 | 12 | class ModelAndLoss(torch.nn.Module): 13 | # Reconsider our approach at a later date: https://ppwwyyxx.com/blog/2022/Loss-Function-Separation/ 14 | 15 | def __init__( 16 | self, 17 | model, 18 | loss_fn: Callable, 19 | ) -> None: 20 | """ 21 | Args: 22 | model: torch module to wrap. 23 | loss_fn: Function for calculating loss, should accept logits and labels. 24 | """ 25 | super().__init__() 26 | self.model = model 27 | self.loss_fn = loss_fn 28 | 29 | def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] 30 | """Runs model forward and calculates loss according to given loss_fn. 31 | 32 | NOTE: The input signature here needs to be a Pipelineable object for 33 | prefetching purposes during training using torchrec's pipeline. However 34 | the underlying model signature needs to be exportable to onnx, requiring 35 | generic python types. see https://pytorch.org/docs/stable/onnx.html#types. 
36 | 37 | """ 38 | outputs = self.model(batch) 39 | losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float()) 40 | 41 | outputs.update( 42 | { 43 | "loss": losses, 44 | "labels": batch.labels, 45 | "weights": batch.weights, 46 | } 47 | ) 48 | 49 | # Allow multiple losses. 50 | return losses, outputs 51 | 52 | 53 | def maybe_shard_model( 54 | model, 55 | device: torch.device, 56 | ): 57 | """Set up and apply DistributedModelParallel to a model if running in a distributed environment. 58 | 59 | If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies 60 | DistributedModelParallel. 61 | 62 | If not in a distributed environment, returns model directly. 63 | """ 64 | if dist.is_initialized(): 65 | logging.info("***** Wrapping in DistributedModelParallel *****") 66 | logging.info(f"Model before wrapping: {model}") 67 | model = DistributedModelParallel( 68 | module=model, 69 | device=device, 70 | ) 71 | logging.info(f"Model after wrapping: {model}") 72 | 73 | return model 74 | 75 | 76 | def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None: 77 | """Handy function to log the content of EBC embedding layer. 78 | Only works for single GPU machines. 
79 | 80 | Args: 81 | weight_name: name of tensor, as defined in model 82 | table_name: name of the EBC table the weight is taken from 83 | weight_tensor: embedding weight tensor 84 | """ 85 | logging.info(f"{weight_name}, {table_name}", rank=-1) 86 | logging.info(f"{weight_tensor.metadata()}", rank=-1) 87 | output_tensor = torch.zeros(*weight_tensor.size(), device=torch.device("cuda:0")) 88 | weight_tensor.gather(out=output_tensor) 89 | logging.info(f"{output_tensor}", rank=-1) 90 | -------------------------------------------------------------------------------- /projects/home/recap/config/home_recap_2022/segdense.json: -------------------------------------------------------------------------------- 1 | { 2 | "schema": [ 3 | { 4 | "dtype": "int64_list", 5 | "feature_name": "home_recap_2022_discrete__segdense_vals", 6 | "length": 320 7 | }, 8 | { 9 | "dtype": "float_list", 10 | "feature_name": "home_recap_2022_cont__segdense_vals", 11 | "length": 6000 12 | }, 13 | { 14 | "dtype": "int64_list", 15 | "feature_name": "home_recap_2022_binary__segdense_vals", 16 | "length": 512 17 | }, 18 | { 19 | "dtype": "int64_list", 20 | "feature_name": "recap.engagement.is_tweet_detail_dwelled_15_sec", 21 | "length": 1 22 | }, 23 | { 24 | "dtype": "int64_list", 25 | "feature_name": "recap.engagement.is_profile_clicked_and_profile_engaged", 26 | "length": 1 27 | }, 28 | { 29 | "dtype": "int64_list", 30 | "feature_name": "recap.engagement.is_replied_reply_engaged_by_author", 31 | "length": 1 32 | }, 33 | { 34 | "dtype": "int64_list", 35 | "feature_name": "recap.engagement.is_video_playback_50", 36 | "length": 1 37 | }, 38 | { 39 | "dtype": "int64_list", 40 | "feature_name": "recap.engagement.is_report_tweet_clicked", 41 | "length": 1 42 | }, 43 | { 44 | "dtype": "int64_list", 45 | "feature_name": "recap.engagement.is_replied", 46 | "length": 1 47 | }, 48 | { 49 | "dtype": "int64_list", 50 | "feature_name": "meta.author_id", 51 | "length": 1 52 | }, 53 | { 54 | "dtype": "int64_list", 
55 | "feature_name": "recap.engagement.is_negative_feedback_v2", 56 | "length": 1 57 | }, 58 | { 59 | "dtype": "int64_list", 60 | "feature_name": "recap.engagement.is_retweeted", 61 | "length": 1 62 | }, 63 | { 64 | "dtype": "int64_list", 65 | "feature_name": "recap.engagement.is_favorited", 66 | "length": 1 67 | }, 68 | { 69 | "dtype": "int64_list", 70 | "feature_name": "recap.engagement.is_good_clicked_convo_desc_favorited_or_replied", 71 | "length": 1 72 | }, 73 | { 74 | "dtype": "int64_list", 75 | "feature_name": "meta.tweet_id", 76 | "length": 1 77 | }, 78 | { 79 | "dtype": "int64_list", 80 | "feature_name": "recap.engagement.is_good_clicked_convo_desc_v2", 81 | "length": 1 82 | }, 83 | { 84 | "dtype": "int64_list", 85 | "feature_name": "meta.user_id", 86 | "length": 1 87 | }, 88 | { 89 | "dtype": "int64_list", 90 | "feature_name": "recap.engagement.is_bookmarked", 91 | "length": 1 92 | }, 93 | { 94 | "dtype": "int64_list", 95 | "feature_name": "recap.engagement.is_shared", 96 | "length": 1 97 | }, 98 | { 99 | "dtype": "float_list", 100 | "feature_name": "user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings", 101 | "length": 200 102 | }, 103 | { 104 | "dtype": "float_list", 105 | "feature_name": "original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings", 106 | "length": 200 107 | }, 108 | { 109 | "dtype": "float_list", 110 | "feature_name": "user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings", 111 | "length": 200 112 | } 113 | ] 114 | } -------------------------------------------------------------------------------- /projects/twhin/models/test_models.py: -------------------------------------------------------------------------------- 1 | from tml.projects.twhin.models.config import TwhinEmbeddingsConfig, TwhinModelConfig 2 | from tml.projects.twhin.data.config import TwhinDataConfig 3 | from tml.common.modules.embedding.config import DataType, EmbeddingBagConfig 4 | from 
tml.optimizers.config import OptimizerConfig, SgdConfig 5 | from tml.model import maybe_shard_model 6 | from tml.projects.twhin.models.models import apply_optimizers, TwhinModel 7 | from tml.projects.twhin.models.config import Operator, Relation 8 | from tml.common.testing_utils import mock_pg 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from pydantic import ValidationError 13 | import pytest 14 | 15 | 16 | NUM_EMBS = 10_000 17 | EMB_DIM = 128 18 | 19 | 20 | def twhin_model_config() -> TwhinModelConfig: 21 | sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01)) 22 | sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02)) 23 | 24 | table0 = EmbeddingBagConfig( 25 | name="table0", 26 | num_embeddings=NUM_EMBS, 27 | embedding_dim=EMB_DIM, 28 | optimizer=sgd_config_0, 29 | data_type=DataType.FP32, 30 | ) 31 | table1 = EmbeddingBagConfig( 32 | name="table1", 33 | num_embeddings=NUM_EMBS, 34 | embedding_dim=EMB_DIM, 35 | optimizer=sgd_config_1, 36 | data_type=DataType.FP32, 37 | ) 38 | embeddings_config = TwhinEmbeddingsConfig( 39 | tables=[table0, table1], 40 | ) 41 | 42 | model_config = TwhinModelConfig( 43 | embeddings=embeddings_config, 44 | translation_optimizer=sgd_config_0, 45 | relations=[ 46 | Relation(name="rel0", lhs="table0", rhs="table1", operator=Operator.TRANSLATION), 47 | Relation(name="rel1", lhs="table1", rhs="table0", operator=Operator.TRANSLATION), 48 | ], 49 | ) 50 | 51 | return model_config 52 | 53 | 54 | def twhin_data_config() -> TwhinDataConfig: 55 | data_config = TwhinDataConfig( 56 | data_root="/", 57 | per_replica_batch_size=10, 58 | global_negatives=10, 59 | in_batch_negatives=10, 60 | limit=1, 61 | offset=1, 62 | ) 63 | 64 | return data_config 65 | 66 | 67 | def test_twhin_model(): 68 | model_config = twhin_model_config() 69 | loss_fn = F.binary_cross_entropy_with_logits 70 | 71 | with mock_pg(): 72 | data_config = twhin_data_config() 73 | model = TwhinModel(model_config=model_config, data_config=data_config) 74 | 75 | 
apply_optimizers(model, model_config) 76 | 77 | for tensor in model.state_dict().values(): 78 | if tensor.size() == (NUM_EMBS, EMB_DIM): 79 | assert str(tensor.device) == "meta" 80 | else: 81 | assert str(tensor.device) == "cpu" 82 | 83 | model = maybe_shard_model(model, device=torch.device("cpu")) 84 | 85 | 86 | def test_unequal_dims(): 87 | sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02)) 88 | sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05)) 89 | table0 = EmbeddingBagConfig( 90 | name="table0", 91 | num_embeddings=10_000, 92 | embedding_dim=128, 93 | optimizer=sgd_config_1, 94 | data_type=DataType.FP32, 95 | ) 96 | table1 = EmbeddingBagConfig( 97 | name="table1", 98 | num_embeddings=10_000, 99 | embedding_dim=64, 100 | optimizer=sgd_config_2, 101 | data_type=DataType.FP32, 102 | ) 103 | 104 | with pytest.raises(ValidationError): 105 | _ = TwhinEmbeddingsConfig( 106 | tables=[table0, table1], 107 | ) 108 | -------------------------------------------------------------------------------- /core/losses.py: -------------------------------------------------------------------------------- 1 | """Loss functions -- including multi task ones.""" 2 | 3 | import typing 4 | 5 | from tml.core.loss_type import LossType 6 | from tml.ml_logging.torch_logging import logging 7 | 8 | import torch 9 | 10 | 11 | def _maybe_warn(reduction: str): 12 | """ 13 | Warning for reduction different than mean. 14 | """ 15 | if reduction != "mean": 16 | logging.warn( 17 | f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal," 18 | f"to the gradient without DDP only for mean reduction. If you need this property for" 19 | f"the provided reduction {reduction}, it needs to be implemented." 
20 | ) 21 | 22 | 23 | def build_loss( 24 | loss_type: LossType, 25 | reduction="mean", 26 | ): 27 | _maybe_warn(reduction) 28 | f = _LOSS_TYPE_TO_FUNCTION[loss_type] 29 | 30 | def loss_fn(logits, labels): 31 | return f(logits, labels.type_as(logits), reduction=reduction) 32 | 33 | return loss_fn 34 | 35 | 36 | def get_global_loss_detached(local_loss, reduction="mean"): 37 | """ 38 | Perform all_reduce to obtain the global loss function using the provided reduction. 39 | :param local_loss: The local loss of the current rank. 40 | :param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP. 41 | :return: The reduced & detached global loss. 42 | """ 43 | if reduction != "mean": 44 | logging.warn( 45 | f"The reduction used in this function should be the same as the one used by " 46 | f"the DDP model. By default DDP uses mean, So ensure that DDP is appropriately" 47 | f"modified for reduction {reduction}." 48 | ) 49 | 50 | if reduction not in ["mean", "sum"]: 51 | raise ValueError(f"Reduction {reduction} is currently unsupported.") 52 | 53 | global_loss = local_loss.detach() 54 | 55 | if reduction == "mean": 56 | global_loss.div_(torch.distributed.get_world_size()) 57 | 58 | torch.distributed.all_reduce(global_loss) 59 | return global_loss 60 | 61 | 62 | def build_multi_task_loss( 63 | loss_type: LossType, 64 | tasks: typing.List[str], 65 | task_loss_reduction="mean", 66 | global_reduction="mean", 67 | pos_weights=None, 68 | ): 69 | _maybe_warn(global_reduction) 70 | _maybe_warn(task_loss_reduction) 71 | f = _LOSS_TYPE_TO_FUNCTION[loss_type] 72 | 73 | loss_reduction_fns = { 74 | "mean": torch.mean, 75 | "sum": torch.sum, 76 | "min": torch.min, 77 | "max": torch.max, 78 | "median": torch.median, 79 | } 80 | 81 | def loss_fn(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor): 82 | if pos_weights is None: 83 | torch_weights = torch.ones([len(tasks)]) 84 | else: 85 | torch_weights = torch.tensor(pos_weights) 86 | 87 | losses 
= {} 88 | for task_idx, task in enumerate(tasks): 89 | task_logits = logits[:, task_idx] 90 | label = labels[:, task_idx].type_as(task_logits) 91 | 92 | loss = f( 93 | task_logits, 94 | label, 95 | reduction=task_loss_reduction, 96 | pos_weight=torch_weights[task_idx], 97 | weight=weights[:, task_idx], 98 | ) 99 | losses[f"loss/{task}"] = loss 100 | 101 | losses["loss"] = loss_reduction_fns[global_reduction](torch.stack(list(losses.values()))) 102 | return losses 103 | 104 | return loss_fn 105 | 106 | 107 | _LOSS_TYPE_TO_FUNCTION = { 108 | LossType.BCE_WITH_LOGITS: torch.nn.functional.binary_cross_entropy_with_logits 109 | } 110 | -------------------------------------------------------------------------------- /images/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | aiofiles==22.1.0 3 | aiohttp==3.8.3 4 | aiosignal==1.3.1 5 | appdirs==1.4.4 6 | arrow==1.2.3 7 | asttokens==2.2.1 8 | astunparse==1.6.3 9 | async-timeout==4.0.2 10 | attrs==22.1.0 11 | backcall==0.2.0 12 | black==22.6.0 13 | cachetools==5.3.0 14 | cblack==22.6.0 15 | certifi==2022.12.7 16 | cfgv==3.3.1 17 | charset-normalizer==2.1.1 18 | click==8.1.3 19 | cmake==3.25.0 20 | Cython==0.29.32 21 | decorator==5.1.1 22 | distlib==0.3.6 23 | distro==1.8.0 24 | dm-tree==0.1.6 25 | docker==6.0.1 26 | docker-pycreds==0.4.0 27 | docstring-parser==0.8.1 28 | exceptiongroup==1.1.0 29 | executing==1.2.0 30 | fbgemm-gpu-cpu==0.3.2 31 | filelock==3.8.2 32 | fire==0.5.0 33 | flatbuffers==1.12 34 | frozenlist==1.3.3 35 | fsspec==2022.11.0 36 | gast==0.4.0 37 | gcsfs==2022.11.0 38 | gitdb==4.0.10 39 | GitPython==3.1.31 40 | google-api-core==2.8.2 41 | google-auth==2.16.0 42 | google-auth-oauthlib==0.4.6 43 | google-cloud-core==2.3.2 44 | google-cloud-storage==2.7.0 45 | google-crc32c==1.5.0 46 | google-pasta==0.2.0 47 | google-resumable-media==2.4.1 48 | googleapis-common-protos==1.56.4 49 | grpcio==1.51.1 50 | h5py==3.8.0 51 | hypothesis==6.61.0 52 
| identify==2.5.17 53 | idna==3.4 54 | importlib-metadata==6.0.0 55 | iniconfig==2.0.0 56 | iopath==0.1.10 57 | ipdb==0.13.11 58 | ipython==8.10.0 59 | jedi==0.18.2 60 | Jinja2==3.1.2 61 | keras==2.9.0 62 | Keras-Preprocessing==1.1.2 63 | libclang==15.0.6.1 64 | libcst==0.4.9 65 | Markdown==3.4.1 66 | MarkupSafe==2.1.1 67 | matplotlib-inline==0.1.6 68 | moreorless==0.4.0 69 | multidict==6.0.4 70 | mypy==1.0.1 71 | mypy-extensions==0.4.3 72 | nest-asyncio==1.5.6 73 | ninja==1.11.1 74 | nodeenv==1.7.0 75 | numpy==1.22.0 76 | nvidia-cublas-cu11==11.10.3.66 77 | nvidia-cuda-nvrtc-cu11==11.7.99 78 | nvidia-cuda-runtime-cu11==11.7.99 79 | nvidia-cudnn-cu11==8.5.0.96 80 | oauthlib==3.2.2 81 | opt-einsum==3.3.0 82 | packaging==22.0 83 | pandas==1.5.3 84 | parso==0.8.3 85 | pathspec==0.11.0 86 | pathtools==0.1.2 87 | pexpect==4.8.0 88 | pickleshare==0.7.5 89 | platformdirs==3.0.0 90 | pluggy==1.0.0 91 | portalocker==2.6.0 92 | portpicker==1.5.2 93 | pre-commit==3.0.4 94 | prompt-toolkit==3.0.36 95 | protobuf==3.20.2 96 | psutil==5.9.4 97 | ptyprocess==0.7.0 98 | pure-eval==0.2.2 99 | pyarrow==10.0.1 100 | pyasn1==0.4.8 101 | pyasn1-modules==0.2.8 102 | pydantic==1.9.0 103 | pyDeprecate==0.3.2 104 | Pygments==2.14.0 105 | pyparsing==3.0.9 106 | pyre-extensions==0.0.27 107 | pytest==7.2.1 108 | pytest-mypy==0.10.3 109 | python-dateutil==2.8.2 110 | pytz==2022.6 111 | PyYAML==6.0.0 112 | requests==2.28.1 113 | requests-oauthlib==1.3.1 114 | rsa==4.9 115 | scikit-build==0.16.3 116 | sentry-sdk==1.16.0 117 | setproctitle==1.3.2 118 | six==1.16.0 119 | smmap==5.0.0 120 | sortedcontainers==2.4.0 121 | stack-data==0.6.2 122 | stdlibs==2022.10.9 123 | tabulate==0.9.0 124 | tensorboard==2.9.0 125 | tensorboard-data-server==0.6.1 126 | tensorboard-plugin-wit==1.8.1 127 | tensorflow==2.9.3 128 | tensorflow-estimator==2.9.0 129 | tensorflow-io-gcs-filesystem==0.30.0 130 | termcolor==2.2.0 131 | toml==0.10.2 132 | tomli==2.0.1 133 | torch==1.13.1 134 | torchmetrics==0.11.0 135 | 
torchrec==0.3.2 136 | torchsnapshot==0.1.0 137 | torchx==0.3.0 138 | tqdm==4.64.1 139 | trailrunner==1.2.1 140 | traitlets==5.9.0 141 | typing-inspect==0.8.0 142 | typing_extensions==4.4.0 143 | urllib3==1.26.13 144 | usort==1.0.5 145 | virtualenv==20.19.0 146 | wandb==0.13.11 147 | wcwidth==0.2.6 148 | websocket-client==1.4.2 149 | Werkzeug==2.2.3 150 | wrapt==1.14.1 151 | yarl==1.8.2 152 | zipp==3.12.1 153 | -------------------------------------------------------------------------------- /core/metric_mixin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mixin that requires a transform to munge output dictionary of tensors a 3 | model produces to a form that the torchmetrics.Metric.update expects. 4 | 5 | By unifying on our signature for `update`, we can also now use 6 | torchmetrics.MetricCollection which requires all metrics have 7 | the same call signature. 8 | 9 | To use, override this with a transform that munges `outputs` 10 | into a kwargs dict that the inherited metric.update accepts. 11 | 12 | Here are two examples of how to extend torchmetrics.SumMetric so that it accepts 13 | an output dictionary of tensors and munges it to what SumMetric expects (single `value`) 14 | for its update method. 15 | 16 | 1. Using as a mixin to inherit from or define a new metric class. 17 | 18 | class Count(MetricMixin, SumMetric): 19 | def transform(self, outputs): 20 | return {'value': 1} 21 | 22 | 2. Redefine an existing metric class. 23 | 24 | SumMetric = prepend_transform(SumMetric, lambda outputs: {'value': 1}) 25 | 26 | """ 27 | from abc import abstractmethod 28 | from typing import Callable, Dict, List 29 | 30 | from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined] 31 | 32 | import torch 33 | import torchmetrics 34 | 35 | 36 | class MetricMixin: 37 | @abstractmethod 38 | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict: 39 | ... 
40 | 41 | def update(self, outputs: Dict[str, torch.Tensor]): 42 | results = self.transform(outputs) 43 | # Do not try to update if any tensor is empty as a result of stratification. 44 | for value in results.values(): 45 | if torch.is_tensor(value) and not value.nelement(): 46 | return 47 | super().update(**results) 48 | 49 | 50 | class TaskMixin: 51 | def __init__(self, task_idx: int = -1, **kwargs): 52 | super().__init__(**kwargs) 53 | self._task_idx = task_idx 54 | 55 | 56 | class StratifyMixin: 57 | def __init__( 58 | self, 59 | stratifier=None, 60 | **kwargs, 61 | ): 62 | super().__init__(**kwargs) 63 | self._stratifier = stratifier 64 | 65 | def maybe_apply_stratification( 66 | self, outputs: Dict[str, torch.Tensor], value_names: List[str] 67 | ) -> Dict[str, torch.Tensor]: 68 | """Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.""" 69 | outputs = outputs.copy() 70 | if not self._stratifier: 71 | return outputs 72 | stratifiers = outputs.get("stratifiers") 73 | if not stratifiers: 74 | return outputs 75 | if stratifiers.get(self._stratifier.name) is None: 76 | return outputs 77 | 78 | mask = torch.flatten(outputs["stratifiers"][self._stratifier.name] == self._stratifier.value) 79 | target_slice = torch.squeeze(mask.nonzero(), -1) 80 | for value_name in value_names: 81 | target = outputs[value_name] 82 | outputs[value_name] = torch.index_select(target, 0, target_slice) 83 | return outputs 84 | 85 | 86 | def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable): 87 | """Returns new class using MetricMixin and given base_metric. 88 | 89 | Functionally the same using inheritance, just saves some lines of code 90 | if no need for class attributes. 
91 | 92 | """ 93 | 94 | def transform_method(_self, *args, **kwargs): 95 | return transform(*args, **kwargs) 96 | 97 | return type( 98 | base_metric.__name__, 99 | ( 100 | MetricMixin, 101 | base_metric, 102 | ), 103 | {"transform": transform_method}, 104 | ) 105 | -------------------------------------------------------------------------------- /projects/twhin/run.py: -------------------------------------------------------------------------------- 1 | from absl import app, flags 2 | import json 3 | from typing import Optional 4 | import os 5 | import sys 6 | 7 | import torch 8 | 9 | # isort: on 10 | from tml.common.device import setup_and_get_device 11 | from tml.common.utils import setup_configuration 12 | import tml.core.custom_training_loop as ctl 13 | import tml.machines.environment as env 14 | from tml.projects.twhin.models.models import apply_optimizers, TwhinModel, TwhinModelAndLoss 15 | from tml.model import maybe_shard_model 16 | from tml.projects.twhin.metrics import create_metrics 17 | from tml.projects.twhin.config import TwhinConfig 18 | from tml.projects.twhin.data.data import create_dataset 19 | from tml.projects.twhin.optimizer import build_optimizer 20 | 21 | from tml.ml_logging.torch_logging import logging 22 | 23 | import torch.distributed as dist 24 | from torch.nn import functional as F 25 | from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward 26 | from torchrec.distributed.model_parallel import get_module 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_bool("overwrite_save_dir", False, "Whether to clear preexisting save directories.") 31 | flags.DEFINE_string("save_dir", None, "If provided, overwrites the save directory.") 32 | flags.DEFINE_string("config_yaml_path", None, "Path to hyperparameters for model.") 33 | flags.DEFINE_string("task", None, "Task to run if this is local. 
Overrides TF_CONFIG etc.") 34 | 35 | 36 | def run( 37 | all_config: TwhinConfig, 38 | save_dir: Optional[str] = None, 39 | ): 40 | train_dataset = create_dataset(all_config.train_data, all_config.model) 41 | 42 | if env.is_reader(): 43 | train_dataset.serve() 44 | if env.is_chief(): 45 | device = setup_and_get_device(tf_ok=False) 46 | logging.info(f"device: {device}") 47 | logging.info(f"WORLD_SIZE: {dist.get_world_size()}") 48 | 49 | # validation_dataset = create_dataset(all_config.validation_data, all_config.model) 50 | 51 | global_batch_size = all_config.train_data.per_replica_batch_size * dist.get_world_size() 52 | 53 | metrics = create_metrics(device) 54 | 55 | model = TwhinModel(all_config.model, all_config.train_data) 56 | apply_optimizers(model, all_config.model) 57 | model = maybe_shard_model(model, device=device) 58 | optimizer, scheduler = build_optimizer(model=model, config=all_config.model) 59 | 60 | loss_fn = F.binary_cross_entropy_with_logits 61 | model_and_loss = TwhinModelAndLoss( 62 | model, loss_fn, data_config=all_config.train_data, device=device 63 | ) 64 | 65 | ctl.train( 66 | model=model_and_loss, 67 | optimizer=optimizer, 68 | device=device, 69 | save_dir=save_dir, 70 | logging_interval=all_config.training.train_log_every_n, 71 | train_steps=all_config.training.num_train_steps, 72 | checkpoint_frequency=all_config.training.checkpoint_every_n, 73 | dataset=train_dataset.dataloader(remote=False), 74 | worker_batch_size=global_batch_size, 75 | num_workers=0, 76 | scheduler=scheduler, 77 | initial_checkpoint_dir=all_config.training.initial_checkpoint_dir, 78 | gradient_accumulation=all_config.training.gradient_accumulation, 79 | ) 80 | 81 | 82 | def main(argv): 83 | logging.info("Starting") 84 | 85 | logging.info(f"parsing config from {FLAGS.config_yaml_path}...") 86 | all_config = setup_configuration( # type: ignore[var-annotated] 87 | TwhinConfig, 88 | yaml_path=FLAGS.config_yaml_path, 89 | ) 90 | 91 | run( 92 | all_config, 93 | 
save_dir=FLAGS.save_dir, 94 | ) 95 | 96 | 97 | if __name__ == "__main__": 98 | app.run(main) 99 | -------------------------------------------------------------------------------- /metrics/aggregation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains aggregation metrics. 3 | """ 4 | from typing import Tuple, Union 5 | 6 | import torch 7 | import torchmetrics 8 | 9 | 10 | def update_mean( 11 | current_mean: torch.Tensor, 12 | current_weight_sum: torch.Tensor, 13 | value: torch.Tensor, 14 | weight: torch.Tensor, 15 | ) -> Tuple[torch.Tensor, torch.Tensor]: 16 | """ 17 | Update the mean according to Welford formula: 18 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. 19 | See also https://nullbuffer.com/articles/welford_algorithm.html for more information. 20 | Args: 21 | current_mean: The value of the current accumulated mean. 22 | current_weight_sum: The current weighted sum. 23 | value: The new value that needs to be added to get a new mean. 24 | weight: The weights for the new value. 25 | 26 | Returns: The updated mean and updated weighted sum. 27 | 28 | """ 29 | weight = torch.broadcast_to(weight, value.shape) 30 | 31 | # Avoiding (on purpose) in-place operation when using += in case 32 | # current_mean and current_weight_sum share the same storage 33 | current_weight_sum = current_weight_sum + torch.sum(weight) 34 | current_mean = current_mean + torch.sum((weight / current_weight_sum) * (value - current_mean)) 35 | return current_mean, current_weight_sum 36 | 37 | 38 | def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor: 39 | """ 40 | Merge the state from multiple workers. 41 | Args: 42 | state: A tensor with the first dimension indicating workers. 43 | 44 | Returns: The accumulated mean from all workers. 
45 | 46 | """ 47 | mean, weight_sum = update_mean( 48 | current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device), 49 | current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device), 50 | value=state[:, 0], 51 | weight=state[:, 1], 52 | ) 53 | return torch.stack([mean, weight_sum]) 54 | 55 | 56 | class StableMean(torchmetrics.Metric): 57 | """ 58 | This implements a numerical stable mean metrics computation using Welford algorithm according to 59 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. 60 | For example when using float32, the algorithm will give a valid output even if the "sum" is larger 61 | than the maximum float32 as far as the mean is within the limit of float32. 62 | See also https://nullbuffer.com/articles/welford_algorithm.html for more information. 63 | """ 64 | 65 | def __init__(self, **kwargs): 66 | """ 67 | Args: 68 | **kwargs: Additional parameters supported by all torchmetrics.Metric. 69 | """ 70 | super().__init__(**kwargs) 71 | self.add_state( 72 | "mean_and_weight_sum", 73 | default=torch.zeros(2), 74 | dist_reduce_fx=stable_mean_dist_reduce_fn, 75 | ) 76 | 77 | def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None: 78 | """ 79 | Update the current mean. 80 | Args: 81 | value: Value to update the mean with. 82 | weight: weight to use. Shape should be broadcastable to that of value. 83 | """ 84 | mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1] 85 | 86 | if not isinstance(weight, torch.Tensor): 87 | weight = torch.as_tensor(weight, dtype=value.dtype, device=value.device) 88 | 89 | self.mean_and_weight_sum[0], self.mean_and_weight_sum[1] = update_mean( 90 | mean, weight_sum, value, torch.as_tensor(weight) 91 | ) 92 | 93 | def compute(self) -> torch.Tensor: 94 | """ 95 | Compute and return the accumulated mean. 
96 | """ 97 | return self.mean_and_weight_sum[0] 98 | -------------------------------------------------------------------------------- /common/log_weights.py: -------------------------------------------------------------------------------- 1 | """For logging model weights.""" 2 | import itertools 3 | from typing import Callable, Dict, List, Optional, Union 4 | 5 | from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined] 6 | import torch 7 | import torch.distributed as dist 8 | from torchrec.distributed.model_parallel import DistributedModelParallel 9 | 10 | 11 | def weights_to_log( 12 | model: torch.nn.Module, 13 | how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None, 14 | ): 15 | """Creates dict of reduced weights to log to give sense of training. 16 | 17 | Args: 18 | model: model to traverse. 19 | how_to_log: if a function, then applies this to every parameter, if a dict 20 | then only applies and logs specified parameters. 21 | 22 | """ 23 | if not how_to_log: 24 | return 25 | 26 | to_log = dict() 27 | named_parameters = model.named_parameters() 28 | logging.info(f"Using DMP: {isinstance(model, DistributedModelParallel)}") 29 | if isinstance(model, DistributedModelParallel): 30 | named_parameters = itertools.chain( 31 | named_parameters, model._dmp_wrapped_module.named_parameters() 32 | ) 33 | logging.info( 34 | f"Using dmp parameters: {list(name for name, _ in model._dmp_wrapped_module.named_parameters())}" 35 | ) 36 | for param_name, params in named_parameters: 37 | if callable(how_to_log): 38 | how = how_to_log 39 | else: 40 | how = how_to_log.get(param_name) # type: ignore[assignment] 41 | if not how: 42 | continue # type: ignore 43 | to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy() 44 | return to_log 45 | 46 | 47 | def log_ebc_norms( 48 | model_state_dict, 49 | ebc_keys: List[str], 50 | sample_size: int = 4_000_000, 51 | ) -> Dict[str, torch.Tensor]: 52 | """Logs the norms of the embedding 
tables as specified by ebc_keys. 53 | As of now, log average norm per rank. 54 | 55 | Args: 56 | model_state_dict: model.state_dict() 57 | ebc_keys: list of embedding keys from state_dict to log. Must contain full name, 58 | i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight 59 | sample_size: Limits number of rows per rank to compute average on to avoid OOM. 60 | """ 61 | norm_logs = dict() 62 | for emb_key in ebc_keys: 63 | norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}")) 64 | if emb_key in model_state_dict: 65 | emb_weight = model_state_dict[emb_key] 66 | try: 67 | emb_weight_tensor = emb_weight.local_tensor() 68 | except AttributeError as e: 69 | logging.info(e) 70 | emb_weight_tensor = emb_weight 71 | logging.info("Running Tensor.detach()") 72 | emb_weight_tensor = emb_weight_tensor.detach() 73 | sample_mask = torch.randperm(emb_weight_tensor.shape[0])[ 74 | : min(sample_size, emb_weight_tensor.shape[0]) 75 | ] 76 | # WARNING: .cpu() transfer executes malloc that may be the cause of memory leaks 77 | # Change sample_size if the you observe frequent OOM errors or remove weight logging. 
78 | norms = emb_weight_tensor[sample_mask].cpu().norm(dim=1).to(torch.float32) 79 | logging.info(f"Norm shape before reduction: {norms.shape}", rank=-1) 80 | norms = norms.mean().to(torch.device(f"cuda:{dist.get_rank()}")) 81 | 82 | all_norms = [ 83 | torch.zeros(1, dtype=norms.dtype).to(norms.device) for _ in range(dist.get_world_size()) 84 | ] 85 | dist.all_gather(all_norms, norms) 86 | for idx, norm in enumerate(all_norms): 87 | if norm != -1.0: 88 | norm_logs[f"{emb_key}-norm-{idx}"] = norm 89 | logging.info(f"Norm Logs are {norm_logs}") 90 | return norm_logs 91 | -------------------------------------------------------------------------------- /projects/home/recap/README.md: -------------------------------------------------------------------------------- 1 | This project is the "heavy ranker" used on the "For You" timeline. This is used to generate the ranking of Tweet after candidate retrieval and light ranker (note the final ordering of the Tweet is not directly the highest -> lowest scoring, because after scoring other heuristics are used). 2 | 3 | This model captures the ranking model used for the majority of users of Twitter "For You" timeline in early March 2023. Due to the need to make sure this runs independently from other parts of Twitter codebase, there may be small differences from the production model. 4 | 5 | The model receives various features, describing the Tweet and the user whose timeline is being constructed as input (see FEATURES.md for more details). The model outputs multiple binary predictions about how the user will respond if shown the Tweet. 6 | 7 | 8 | Those are: 9 | "recap.engagement.is_favorited": The probability the user will favorite the Tweet. 10 | "recap.engagement.is_good_clicked_convo_desc_favorited_or_replied": The probability the user will click into the conversation of this Tweet and reply or Like a Tweet. 
11 | "recap.engagement.is_good_clicked_convo_desc_v2": The probability the user will click into the conversation of this Tweet and stay there for at least 2 minutes. 12 | "recap.engagement.is_negative_feedback_v2": The probability the user will react negatively (requesting "show less often" on the Tweet or author, or blocking or muting the Tweet author). 13 | "recap.engagement.is_profile_clicked_and_profile_engaged": The probability the user opens the Tweet author's profile and Likes or replies to a Tweet. 14 | "recap.engagement.is_replied": The probability the user replies to the Tweet. 15 | "recap.engagement.is_replied_reply_engaged_by_author": The probability the user replies to the Tweet and this reply is engaged by the Tweet author. 16 | "recap.engagement.is_report_tweet_clicked": The probability the user will click Report Tweet. 17 | "recap.engagement.is_retweeted": The probability the user will Retweet the Tweet. 18 | "recap.engagement.is_video_playback_50": The probability (for a video Tweet) that the user will watch at least half of the video. 19 | 20 | For ranking the candidates, these predictions are combined into a score by weighting them: 21 | "recap.engagement.is_favorited": 0.5 22 | "recap.engagement.is_good_clicked_convo_desc_favorited_or_replied": 11* (the maximum prediction from these two "good click" features is used and weighted by 11, the other prediction is ignored). 23 | "recap.engagement.is_good_clicked_convo_desc_v2": 11* 24 | "recap.engagement.is_negative_feedback_v2": -74 25 | "recap.engagement.is_profile_clicked_and_profile_engaged": 12 26 | "recap.engagement.is_replied": 27 27 | "recap.engagement.is_replied_reply_engaged_by_author": 75 28 | "recap.engagement.is_report_tweet_clicked": -369 29 | "recap.engagement.is_retweeted": 1 30 | "recap.engagement.is_video_playback_50": 0.005 31 | 32 | 33 | We cannot release the real training data due to privacy restrictions.
However, we have included a script to generate random data to ensure you can run the model training code. 34 | 35 | To try training the model (assuming you have already followed the repo setup instructions and are inside a virtualenv): 36 | 37 | Run 38 | $ ./projects/home/recap/script/create_random_data.sh 39 | 40 | This will create some random data (in $HOME/tmp/recap_local_random_data). 41 | 42 | $ ./projects/home/recap/script/run_local.sh 43 | 44 | This will train the model (for a small number of iterations). Checkpoints and logs will be written to $HOME/tmp/runs/recap_local_debug. 45 | 46 | The model training is configured through a yaml file (./projects/home/recap/config/local_prod.yaml). 47 | 48 | The model architecture is a parallel masknet (https://arxiv.org/abs/2102.07619). 49 | -------------------------------------------------------------------------------- /projects/home/recap/main.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | from typing import Callable, List, Optional, Tuple 4 | import tensorflow as tf 5 | 6 | import tml.common.checkpointing.snapshot as snapshot_lib 7 | from tml.common.device import setup_and_get_device 8 | from tml.core import config as tml_config_mod 9 | import tml.core.custom_training_loop as ctl 10 | from tml.core import debug_training_loop 11 | from tml.core import losses 12 | from tml.core.loss_type import LossType 13 | from tml.model import maybe_shard_model 14 | 15 | 16 | import tml.projects.home.recap.data.dataset as ds 17 | import tml.projects.home.recap.config as recap_config_mod 18 | import tml.projects.home.recap.optimizer as optimizer_mod 19 | 20 | 21 | # from tml.projects.home.recap import feature 22 | import tml.projects.home.recap.model as model_mod 23 | import torchmetrics as tm 24 | import torch 25 | import torch.distributed as dist 26 | from torchrec.distributed.model_parallel import DistributedModelParallel 27 | 28 | from absl import
app, flags, logging

flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
flags.DEFINE_bool("debug_loop", False, "Run with debug loop (slow)")

FLAGS = flags.FLAGS


def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):
    """Build the Home Recap ranking model described by ``--config_path`` and train it.

    Args:
      unused_argv: Positional command-line arguments forwarded by ``app.run``; ignored.
      data_service_dispatcher: Optional dataset-service address, passed to
        ``RecapDataset`` as ``dataset_service``.
    """
    print("#" * 100)

    config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
    logging.info("Config: %s", config.pretty_print())

    device = setup_and_get_device()

    # Always enable tensorfloat on supported devices.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    loss_fn = losses.build_multi_task_loss(
        loss_type=LossType.BCE_WITH_LOGITS,
        tasks=list(config.model.tasks.keys()),
        pos_weights=[task.pos_weight for task in config.model.tasks.values()],
    )

    # Since the prod model doesn't use large embeddings, for now we won't support them.
    assert config.model.large_embeddings is None

    train_dataset = ds.RecapDataset(
        data_config=config.train_data,
        dataset_service=data_service_dispatcher,
        mode=recap_config_mod.JobMode.TRAIN,
        compression=config.train_data.dataset_service_compression,
        vocab_mapper=None,
        repeat=True,
    )

    train_iterator = iter(train_dataset.to_dataloader())

    torch_element_spec = train_dataset.torch_element_spec

    model = model_mod.create_ranking_model(
        data_spec=torch_element_spec[0],
        config=config,
        loss_fn=loss_fn,
        device=device,
    )

    optimizer, scheduler = optimizer_mod.build_optimizer(model, config.optimizer, None)

    model = maybe_shard_model(model, device)

    datetime_str = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")
    print(f"{datetime_str}\n", end="")

    # Both loop modules are driven through the same `train(...)` call below; the debug one is
    # selected by --debug_loop and is slower.
    if FLAGS.debug_loop:
        logging.warning("Running debug mode, slow!")
        train_mod = debug_training_loop
    else:
        train_mod = ctl

    train_mod.train(
        model=model,
        optimizer=optimizer,
        device=device,
        save_dir=config.training.save_dir,
        logging_interval=config.training.train_log_every_n,
        train_steps=config.training.num_train_steps,
        checkpoint_frequency=config.training.checkpoint_every_n,
        dataset=train_iterator,
        worker_batch_size=config.train_data.global_batch_size,
        enable_amp=False,
        initial_checkpoint_dir=config.training.initial_checkpoint_dir,
        gradient_accumulation=config.training.gradient_accumulation,
        scheduler=scheduler,
    )


if __name__ == "__main__":
    app.run(run)

# ------------------------------------------------------------------------------
# /optimizers/optimizer.py
# ------------------------------------------------------------------------------
from typing import Dict, Tuple
import math
import bisect

from tml.optimizers.config import (
    LearningRate,
    OptimizerConfig,
)

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from tml.ml_logging.torch_logging import logging


def compute_lr(lr_config, step):
    """Compute the learning rate at ``step`` for a ``LearningRate`` config.

    The first schedule that is set on ``lr_config`` is used:

    * ``constant``: the same rate at every step.
    * ``piecewise_constant``: ``learning_rate_values[i]``, where ``i`` is the number of
      ``learning_rate_boundaries`` that are <= ``step``.
    * ``linear_ramp_to_constant``: linear ramp from 0 to ``learning_rate`` over
      ``num_ramp_steps`` steps, then held at ``learning_rate``.
    * ``linear_ramp_to_cosine``: linear ramp to ``learning_rate`` over ``num_ramp_steps``,
      cosine decay to ``final_learning_rate`` at ``final_num_steps``, then held there.

    Raises:
      ValueError: if no schedule is set on ``lr_config``.
    """
    if lr_config.constant is not None:
        return lr_config.constant
    elif lr_config.piecewise_constant is not None:
        return lr_config.piecewise_constant.learning_rate_values[
            bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
        ]
    elif lr_config.linear_ramp_to_constant is not None:
        slope = (
            lr_config.linear_ramp_to_constant.learning_rate
            / lr_config.linear_ramp_to_constant.num_ramp_steps
        )
        return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step)
    elif lr_config.linear_ramp_to_cosine is not None:
        cfg = lr_config.linear_ramp_to_cosine
        if step < cfg.num_ramp_steps:
            slope = cfg.learning_rate / cfg.num_ramp_steps
            return slope * step
        elif step <= cfg.final_num_steps:
            return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
                1.0
                + math.cos(
                    math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
                )
            )
        else:
            return cfg.final_learning_rate
    else:
        raise ValueError(f"No option selected in lr_config, passed {lr_config}")


def _count_param_groups(optimizer) -> int:
    """Count parameter groups, looking inside combined optimizers that expose ``_optims``."""
    if hasattr(optimizer, "_optims"):
        # Combined optimizers hold (name, optimizer) pairs; count every inner group.
        return sum(len(sub_optimizer.param_groups) for _, sub_optimizer in optimizer._optims)
    # A plain torch optimizer (e.g. the one `build_optimizer` creates) has no `_optims`.
    return len(optimizer.param_groups)


class LRShim(_LRScheduler):
    """Shim to get learning rates into a LRScheduler.

    This adheres to the torch.optim scheduler API and can be plugged anywhere that
    e.g. exponential decay can be used.

    ``lr_dict`` maps a group name to its ``LearningRate`` schedule; it must have exactly one
    entry per optimizer parameter group, in parameter-group order.
    """

    def __init__(
        self,
        optimizer,
        lr_dict: Dict[str, LearningRate],
        last_epoch=-1,
        verbose=False,
    ):
        self.optimizer = optimizer
        self.lr_dict = lr_dict
        self.group_names = list(self.lr_dict.keys())

        num_param_groups = _count_param_groups(optimizer)
        if num_param_groups != len(lr_dict):
            raise ValueError(f"Optimizer had {num_param_groups}, but config had {len(lr_dict)}.")

        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            # `logging.warn` is deprecated and a trailing `UserWarning` argument would be
            # treated as a %-format argument for the message.
            logging.warning(
                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`."
            )
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # `last_epoch` is the scheduler's step counter; one rate per configured group.
        return [compute_lr(lr_config, self.last_epoch) for lr_config in self.lr_dict.values()]


def _select_optimizer(optimizer_config: OptimizerConfig):
    """Return ``(torch optimizer class, sub-config)`` for the first optimizer set, else ``(None, None)``."""
    for field_name, optimizer_class in (
        ("adam", torch.optim.Adam),
        ("sgd", torch.optim.SGD),
        ("adagrad", torch.optim.Adagrad),
    ):
        sub_config = getattr(optimizer_config, field_name)
        if sub_config is not None:
            return optimizer_class, sub_config
    return None, None


def get_optimizer_class(optimizer_config: OptimizerConfig):
    """Return the torch optimizer class selected in ``optimizer_config``, or ``None`` if unset."""
    optimizer_class, _ = _select_optimizer(optimizer_config)
    return optimizer_class


def build_optimizer(
    model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]:
    """Builds an optimizer and LR scheduler from an OptimizerConfig.
    Note: use this when you want the same optimizer and learning rate schedule for all your parameters.

    The optimizer is constructed from the sub-config that was actually selected (adam, sgd or
    adagrad), not unconditionally from ``sgd``.

    Raises:
      ValueError: if no optimizer is selected in ``optimizer_config``.
    """
    optimizer_class, selected_config = _select_optimizer(optimizer_config)
    if optimizer_class is None:
        raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")
    optimizer = optimizer_class(model.parameters(), **selected_config.dict())
    # We're passing everything in as one group here
    scheduler = LRShim(optimizer, lr_dict={"ALL_PARAMS": optimizer_config.learning_rate})
    return optimizer, scheduler

# ------------------------------------------------------------------------------
# /common/run_training.py
# ------------------------------------------------------------------------------
import os
import subprocess
import sys
from typing import Optional

from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
from twitter.ml.tensorflow.experimental.distributed import utils

import torch
import torch.distributed.run


def is_distributed_worker():
    """Return True when both WORLD_SIZE and RANK are present in the environment."""
    world_size = os.environ.get("WORLD_SIZE", None)
    rank = os.environ.get("RANK", None)
    return world_size is not None and rank is not None


def maybe_run_training(
    train_fn,
    module_name,
    nproc_per_node: Optional[int] = None,
    num_nodes: Optional[int] = None,
    set_python_path_in_subprocess: bool = False,
    is_chief: Optional[bool] = False,
    **training_kwargs,
):
    """Wrapper function for single node, multi-GPU Pytorch training.

    If the necessary distributed Pytorch environment variables
    (WORLD_SIZE, RANK) have been set, then this function executes
    `train_fn(**training_kwargs)`.

    Otherwise, this function calls torchrun and points at the calling module
    `module_name`. After this call, the necessary environment variables are set
    and training will commence.

    Args:
      train_fn: The function that is responsible for training
      module_name: The name of the module that this function was called from;
        used to indicate torchrun entrypoint.
      nproc_per_node: Number of worker processes per node. Defaults to the local CUDA
        device count, or to the chief machine's accelerator count when CUDA is unavailable.
      num_nodes: Number of nodes, otherwise inferred from environment.
      set_python_path_in_subprocess: If True, launch `torchrun` as a subprocess with
        PYTHONPATH set to this process's `sys.path`; otherwise call torchrun in-process.
      is_chief: If process is running on chief.
      training_kwargs: Forwarded to `train_fn`, or re-serialized as `--key=value` flags
        for the torchrun-spawned processes.

    Raises:
      subprocess.CalledProcessError: if the `torchrun` subprocess exits with an error
        (the in-process launcher likewise raises on worker failure).
    """
    machines = utils.machine_from_env()
    if num_nodes is None:
        num_nodes = 1
        if machines.num_workers:
            num_nodes += machines.num_workers

    if is_distributed_worker():
        # world_size, rank, etc are set; assuming any other env vars are set (checks to come)
        # start the actual training!
        train_fn(**training_kwargs)
        return

    if nproc_per_node is None:
        if torch.cuda.is_available():
            nproc_per_node = torch.cuda.device_count()
        else:
            nproc_per_node = machines.chief.num_accelerators

    # Re-serialize the training kwargs as CLI flags so the processes spawned by torchrun
    # receive them when they re-enter `module_name`.
    args = [f"--{key}={val}" for key, val in training_kwargs.items()]

    cmd = ["--nnodes", str(num_nodes)]
    if nproc_per_node:
        cmd.extend(["--nproc_per_node", str(nproc_per_node)])
    if num_nodes > 1:
        cluster_resolver = utils.cluster_resolver()
        backend_address = cluster_resolver.cluster_spec().task_address("chief", 0)
        cmd.extend(["--rdzv_backend", "c10d", "--rdzv_id", backend_address])
        # Set localhost on chief because of https://github.com/pytorch/pytorch/issues/79388
        rdzv_endpoint = "localhost:2222" if is_chief else backend_address
        cmd.extend(["--rdzv_endpoint", rdzv_endpoint])
    else:
        cmd.append("--standalone")

    cmd.extend([str(module_name), *args])
    logging.info(f"""Distributed running with cmd: '{" ".join(cmd)}'""")

    # Call torchrun on this module; will spawn new processes and re-run this
    # function, eventually calling "train_fn". The subprocess variant sets the PYTHONPATH to
    # accommodate bazel stubbing for the main binary.
    if set_python_path_in_subprocess:
        subprocess.run(
            ["torchrun"] + cmd,
            env={**os.environ, "PYTHONPATH": ":".join(sys.path)},
            # Surface training failures instead of returning as if training succeeded.
            check=True,
        )
    else:
        torch.distributed.run.main(cmd)

# ------------------------------------------------------------------------------
# /projects/home/recap/model/mask_net.py
# ------------------------------------------------------------------------------
"""MaskNet: Wang et al. (https://arxiv.org/abs/2102.07619)."""

from tml.projects.home.recap.model import config, mlp

import torch


def _init_weights(module):
    """Xavier-uniform weights and zero bias for Linear layers; other modules are untouched."""
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight)
        torch.nn.init.constant_(module.bias, 0)


class MaskBlock(torch.nn.Module):
    """One MaskNet block: ``LayerNorm(Linear(net * mask(mask_input)))``.

    ``mask`` is Linear -> ReLU -> Linear, projecting ``mask_input`` (width ``mask_input_dim``)
    to the width of ``net`` (``input_dim``). ``net`` is optionally layer-normalized first.
    """

    def __init__(
        self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
    ) -> None:
        super(MaskBlock, self).__init__()
        self.mask_block_config = mask_block_config
        output_size = mask_block_config.output_size

        if mask_block_config.input_layer_norm:
            self._input_layer_norm = torch.nn.LayerNorm(input_dim)
        else:
            self._input_layer_norm = None

        # Hidden width of the mask MLP: a fraction of the mask input, or an explicit size.
        if mask_block_config.reduction_factor:
            aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)
        elif mask_block_config.aggregation_size is not None:
            aggregation_size = mask_block_config.aggregation_size
        else:
            raise ValueError("Need one of reduction factor or aggregation size.")

        self._mask_layer = torch.nn.Sequential(
            torch.nn.Linear(mask_input_dim, aggregation_size),
            torch.nn.ReLU(),
            torch.nn.Linear(aggregation_size, input_dim),
        )
        self._mask_layer.apply(_init_weights)
        self._hidden_layer = torch.nn.Linear(input_dim, output_size)
        self._hidden_layer.apply(_init_weights)
        self._layer_norm = torch.nn.LayerNorm(output_size)

    def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
        if self._input_layer_norm is not None:
            net = self._input_layer_norm(net)
        hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
        return self._layer_norm(hidden_layer_output)


class MaskNet(torch.nn.Module):
    """Parallel or serial arrangement of MaskBlocks, optionally followed by an MLP.

    ``forward`` returns ``{"output": ..., "shared_layer": ...}``, where ``shared_layer`` is the
    MaskBlock output before the optional MLP.
    """

    def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
        super().__init__()
        self.mask_net_config = mask_net_config
        mask_blocks = []

        if mask_net_config.use_parallel:
            # Every block sees the raw input; their outputs are concatenated.
            total_output_mask_blocks = 0
            for mask_block_config in mask_net_config.mask_blocks:
                mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
                total_output_mask_blocks += mask_block_config.output_size
        else:
            # Blocks are chained: each transforms the previous block's output, masked by
            # the raw input.
            input_size = in_features
            for mask_block_config in mask_net_config.mask_blocks:
                mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
                input_size = mask_block_config.output_size
            # Width of the last block's output (not the leaked loop variable, which would be
            # unbound when there are no blocks).
            total_output_mask_blocks = input_size
        self._mask_blocks = torch.nn.ModuleList(mask_blocks)

        if mask_net_config.mlp:
            self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
            self.out_features = mask_net_config.mlp.layer_sizes[-1]
        else:
            self.out_features = total_output_mask_blocks
        self.shared_size = total_output_mask_blocks

    def forward(self, inputs: torch.Tensor):
        if self.mask_net_config.use_parallel:
            mask_outputs = [
                mask_layer(mask_input=inputs, net=inputs) for mask_layer in self._mask_blocks
            ]
            # Share the outputs of the MaskBlocks.
            shared = torch.cat(mask_outputs, dim=1)
        else:
            shared = inputs
            for mask_layer in self._mask_blocks:
                shared = mask_layer(net=shared, mask_input=inputs)
            # Share the output of the stacked MaskBlocks.
        # The MLP is a module and must be called; indexing it (`self._dense_layers[net]`) fails.
        output = shared if self.mask_net_config.mlp is None else self._dense_layers(shared)["output"]
        return {"output": output, "shared_layer": shared}

# ------------------------------------------------------------------------------
# /projects/home/recap/embedding/config.py
# ------------------------------------------------------------------------------
from typing import List, Optional
import tml.core.config as base_config
from tml.optimizers import config as optimizer_config

import pydantic


class EmbeddingSnapshot(base_config.BaseConfig):
    """Configuration for Embedding snapshot"""

    emb_name: str = pydantic.Field(
        ..., description="Name of the embedding table from the loaded snapshot"
    )
    embedding_snapshot_uri: str = pydantic.Field(
        ..., description="Path to torchsnapshot of the embedding"
    )


# https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_configs.EmbeddingBagConfig
class EmbeddingBagConfig(base_config.BaseConfig):
    """Configuration for EmbeddingBag."""

    name: str = pydantic.Field(..., description="name of embedding bag")
    num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
    embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
    pretrained: Optional[EmbeddingSnapshot] = pydantic.Field(
        None, description="Snapshot properties"
    )
    vocab: Optional[str] = pydantic.Field(
        None, description="Directory to parquet files of mapping from entity ID to table index."
    )


class EmbeddingOptimizerConfig(base_config.BaseConfig):
    """Optimizer settings for the large-embedding EmbeddingBagCollection (EBC)."""

    learning_rate: Optional[optimizer_config.LearningRate] = pydantic.Field(
        None, description="learning rate scheduler for the EBC"
    )
    init_learning_rate: float = pydantic.Field(description="initial learning rate for the EBC")
    # NB: Only sgd is supported right now and implicitly.
    # FBGemm only supports simple exact_sgd which only takes LR as an argument.


class LargeEmbeddingsConfig(base_config.BaseConfig):
    """Configuration for EmbeddingBagCollection.

    The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.
    """

    tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables")
    optimizer: EmbeddingOptimizerConfig
    tables_to_log: Optional[List[str]] = pydantic.Field(
        None, description="list of embedding table names that we want to log during training"
    )


class StratifierConfig(base_config.BaseConfig):
    """A named stratifier: ``name`` plus integer ``index`` and ``value`` fields."""

    name: str
    index: int
    value: int


# NB: this class used to be defined twice with identical bodies; one definition is enough.
class SmallEmbeddingBagConfig(base_config.BaseConfig):
    """Configuration for SmallEmbeddingBag."""

    name: str = pydantic.Field(..., description="name of embedding bag")
    num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
    embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
    index: int = pydantic.Field(..., description="index in the discrete tensor to look for")


class SmallEmbeddingsConfig(base_config.BaseConfig):
    """Configuration for SmallEmbeddingConfig.

    Here we can use discrete features that already are present in our TFRecords generated using
    segdense conversion as "home_recap_2022_discrete__segdense_vals" which are available in
    the model as "discrete_features", and embed a user-defined set of them with configurable
    dimensions and vocabulary sizes.

    Compared with LargeEmbedding, this config is for small embedding tables that can fit inside
    the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at
    serving time due to size (>>1 GB).

    This small embeddings table uses the same optimizer as the rest of the model."""

    tables: List[SmallEmbeddingBagConfig] = pydantic.Field(
        ..., description="list of embedding tables"
    )

# ------------------------------------------------------------------------------
# /projects/home/recap/data/util.py
# ------------------------------------------------------------------------------
from typing import Mapping, Tuple, Union
import torch
import torchrec
import numpy as np
import tensorflow as tf


def keyed_tensor_from_tensors_dict(
    tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedTensor":
    """
    Convert a dictionary of torch tensor to torchrec keyed tensor
    Args:
      tensor_map: Feature name -> tensor whose first dimension is the batch dimension.

    Returns:
      A KeyedTensor with the keys in ``tensor_map`` iteration order.
    """
    keys = list(tensor_map.keys())
    # We expect batch size to be first dim. However, if we get a shape [Batch_size],
    # KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
    # [Batch_size x 1].
    values = [
        tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1)
        for key in keys
    ]
    return torchrec.KeyedTensor.from_tensor_list(keys, values)


def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return ``(values, lengths)`` for one feature tensor.

    Sparse tensors contribute their stored values, with one length per row of dimension 0;
    dense tensors give every row of dimension 0 a length of 1.
    """
    if tensor.is_sparse:
        x = tensor.coalesce()  # Ensure that the indices are ordered.
        # minlength keeps one length per row even when trailing rows hold no values; without
        # it bincount stops at the last non-empty row and `lengths` is too short.
        lengths = torch.bincount(x.indices()[0], minlength=x.shape[0])
        values = x.values()
    else:
        values = tensor
        lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)
    return values, lengths


def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
    """
    Convert a torch tensor to torchrec jagged tensor.
    Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors.
    For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the
    dense_shape of the sparse tensor can be arbitrary.
    Args:
      tensor: a torch (sparse) tensor.
    Returns:
      A JaggedTensor built from the tensor's values and per-row lengths.
    """
    values, lengths = _compute_jagged_tensor_from_tensor(tensor)
    return torchrec.JaggedTensor(values=values, lengths=lengths)


def keyed_jagged_tensor_from_tensors_dict(
    tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedJaggedTensor":
    """
    Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor.
    Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors.
    For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the
    dense_shape of the sparse tensor can be arbitrary.
    Args:
      tensor_map: Feature name -> (sparse) tensor.

    Returns:
      A KeyedJaggedTensor whose values/lengths are the per-key values/lengths concatenated in
      ``tensor_map`` iteration order (an empty KJT for an empty map).
    """

    if not tensor_map:
        return torchrec.KeyedJaggedTensor(
            keys=[],
            values=torch.zeros(0, dtype=torch.int),
            lengths=torch.zeros(0, dtype=torch.int),
        )
    values = []
    lengths = []
    for tensor in tensor_map.values():
        tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor)
        # squeeze turns [B, 1] into [B]; atleast_1d stops a batch of one from collapsing to a
        # 0-d scalar, which torch.cat cannot concatenate.
        values.append(torch.atleast_1d(torch.squeeze(tensor_val)))
        lengths.append(tensor_len)

    values = torch.cat(values, dim=0)
    lengths = torch.cat(lengths, dim=0)

    return torchrec.KeyedJaggedTensor(
        keys=list(tensor_map.keys()),
        values=values,
        lengths=lengths,
    )


def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
    return tf_tensor._numpy()  # noqa


def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
    """Convert a dense TF tensor to torch (via NumPy), optionally into pinned memory."""
    tensor = _tf_to_numpy(tensor)
    # Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
    # (more precisely: torch.from_numpy cannot take NumPy's bfloat16 extension dtype).
    if tensor.dtype.name == "bfloat16":
        tensor = tensor.astype(np.float32)

    tensor = torch.from_numpy(tensor)
    if pin_memory:
        tensor = tensor.pin_memory()
    return tensor


def sparse_or_dense_tf_to_torch(
    tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
) -> torch.Tensor:
    """Convert a TF dense tensor or SparseTensor to the matching torch (sparse COO) tensor."""
    if isinstance(tensor, tf.SparseTensor):
        tensor = torch.sparse_coo_tensor(
            # TF stores indices as [nnz, ndim]; torch expects [ndim, nnz].
            _dense_tf_to_torch(tensor.indices, pin_memory).t(),
            _dense_tf_to_torch(tensor.values, pin_memory),
            torch.Size(_tf_to_numpy(tensor.dense_shape)),
        )
    else:
        tensor = _dense_tf_to_torch(tensor, pin_memory)
    return tensor

# ------------------------------------------------------------------------------
# /reader/dds.py
# ------------------------------------------------------------------------------
"""Dataset service orchestrated by a TFJob
"""
from typing import Optional
import uuid
5 | 6 | from tml.ml_logging.torch_logging import logging 7 | import tml.machines.environment as env 8 | 9 | import packaging.version 10 | import tensorflow as tf 11 | 12 | try: 13 | import tensorflow_io as tfio 14 | except: 15 | pass 16 | from tensorflow.python.data.experimental.ops.data_service_ops import ( 17 | _from_dataset_id, 18 | _register_dataset, 19 | ) 20 | import torch.distributed as dist 21 | 22 | 23 | def maybe_start_dataset_service(): 24 | if not env.has_readers(): 25 | return 26 | 27 | if packaging.version.parse(tf.__version__) < packaging.version.parse("2.5"): 28 | raise Exception(f"maybe_distribute_dataset requires TF >= 2.5; got {tf.__version__}") 29 | 30 | if env.is_dispatcher(): 31 | logging.info(f"env.get_reader_port() = {env.get_reader_port()}") 32 | logging.info(f"env.get_dds_journaling_dir() = {env.get_dds_journaling_dir()}") 33 | work_dir = env.get_dds_journaling_dir() 34 | server = tf.data.experimental.service.DispatchServer( 35 | tf.data.experimental.service.DispatcherConfig( 36 | port=env.get_reader_port(), 37 | protocol="grpc", 38 | work_dir=work_dir, 39 | fault_tolerant_mode=bool(work_dir), 40 | ) 41 | ) 42 | server.join() 43 | 44 | elif env.is_reader(): 45 | logging.info(f"env.get_reader_port() = {env.get_reader_port()}") 46 | logging.info(f"env.get_dds_dispatcher_address() = {env.get_dds_dispatcher_address()}") 47 | logging.info(f"env.get_dds_worker_address() = {env.get_dds_worker_address()}") 48 | server = tf.data.experimental.service.WorkerServer( 49 | tf.data.experimental.service.WorkerConfig( 50 | port=env.get_reader_port(), 51 | dispatcher_address=env.get_dds_dispatcher_address(), 52 | worker_address=env.get_dds_worker_address(), 53 | protocol="grpc", 54 | ) 55 | ) 56 | server.join() 57 | 58 | 59 | def register_dataset( 60 | dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO" 61 | ): 62 | if dist.get_rank() == 0: 63 | dataset_id = _register_dataset( 64 | service=dataset_service, 65 | 
dataset=dataset, 66 | compression=compression, 67 | ) 68 | job_name = uuid.uuid4().hex[:8] 69 | id_and_job = [dataset_id.numpy(), job_name] 70 | logging.info(f"rank{dist.get_rank()}: Created dds job with {dataset_id.numpy()}, {job_name}") 71 | else: 72 | id_and_job = [None, None] 73 | 74 | dist.broadcast_object_list(id_and_job, src=0) 75 | return tuple(id_and_job) 76 | 77 | 78 | def distribute_from_dataset_id( 79 | dataset_service: str, 80 | dataset_id: int, 81 | job_name: Optional[str], 82 | compression: Optional[str] = "AUTO", 83 | prefetch: Optional[int] = tf.data.experimental.AUTOTUNE, 84 | ) -> tf.data.Dataset: 85 | logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}") 86 | dataset = _from_dataset_id( 87 | processing_mode="parallel_epochs", 88 | service=dataset_service, 89 | dataset_id=dataset_id, 90 | job_name=job_name, 91 | element_spec=None, 92 | compression=compression, 93 | ) 94 | if prefetch is not None: 95 | dataset = dataset.prefetch(prefetch) 96 | return dataset 97 | 98 | 99 | def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset: 100 | """Torch-compatible and distributed-training-aware dataset service distributor. 101 | 102 | - rank 0 process will register the given dataset. 103 | - rank 0 process will broadcast job name and dataset id. 104 | - all rank processes will consume from the same job/dataset. 105 | 106 | Without this, dataset workers will try to serve 1 job per rank process and OOM. 
107 | 108 | """ 109 | if not env.has_readers(): 110 | return dataset 111 | dataset_service = env.get_dds() 112 | 113 | logging.info(f"using DDS = {dataset_service}") 114 | dataset_id, job_name = register_dataset(dataset=dataset, dataset_service=dataset_service) 115 | dataset = distribute_from_dataset_id( 116 | dataset_service=dataset_service, dataset_id=dataset_id, job_name=job_name 117 | ) 118 | return dataset 119 | 120 | 121 | if __name__ == "__main__": 122 | maybe_start_dataset_service() 123 | -------------------------------------------------------------------------------- /projects/home/recap/model/feature_transform.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Sequence, Union 2 | 3 | from tml.projects.home.recap.model.config import ( 4 | BatchNormConfig, 5 | DoubleNormLogConfig, 6 | FeaturizationConfig, 7 | LayerNormConfig, 8 | ) 9 | 10 | import torch 11 | 12 | 13 | def log_transform(x: torch.Tensor) -> torch.Tensor: 14 | """Safe log transform that works across both negative, zero, and positive floats.""" 15 | return torch.sign(x) * torch.log1p(torch.abs(x)) 16 | 17 | 18 | class BatchNorm(torch.nn.Module): 19 | def __init__(self, num_features: int, config: BatchNormConfig): 20 | super().__init__() 21 | self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum) 22 | 23 | def forward(self, x: torch.Tensor) -> torch.Tensor: 24 | return self.layer(x) 25 | 26 | 27 | class LayerNorm(torch.nn.Module): 28 | def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig): 29 | super().__init__() 30 | if config.axis != -1: 31 | raise NotImplementedError 32 | if config.center != config.scale: 33 | raise ValueError( 34 | f"Center and scale must match in torch, received {config.center}, {config.scale}" 35 | ) 36 | self.layer = torch.nn.LayerNorm( 37 | normalized_shape, eps=config.epsilon, elementwise_affine=config.center 38 | ) 39 | 40 | def 
forward(self, x: torch.Tensor) -> torch.Tensor: 41 | return self.layer(x) 42 | 43 | 44 | class Log1pAbs(torch.nn.Module): 45 | def __init__(self): 46 | super().__init__() 47 | 48 | def forward(self, x: torch.Tensor) -> torch.Tensor: 49 | return log_transform(x) 50 | 51 | 52 | class InputNonFinite(torch.nn.Module): 53 | def __init__(self, fill_value: float = 0): 54 | super().__init__() 55 | 56 | self.register_buffer( 57 | "fill_value", torch.as_tensor(fill_value, dtype=torch.float32), persistent=False 58 | ) 59 | 60 | def forward(self, x: torch.Tensor) -> torch.Tensor: 61 | return torch.where(torch.isfinite(x), x, self.fill_value) 62 | 63 | 64 | class Clamp(torch.nn.Module): 65 | def __init__(self, min_value: float, max_value: float): 66 | super().__init__() 67 | # Using buffer to make sure they are on correct device (and not moved every time). 68 | # Will also be part of state_dict. 69 | self.register_buffer( 70 | "min_value", torch.as_tensor(min_value, dtype=torch.float32), persistent=True 71 | ) 72 | self.register_buffer( 73 | "max_value", torch.as_tensor(max_value, dtype=torch.float32), persistent=True 74 | ) 75 | 76 | def forward(self, x: torch.Tensor) -> torch.Tensor: 77 | return torch.clamp(x, min=self.min_value, max=self.max_value) 78 | 79 | 80 | class DoubleNormLog(torch.nn.Module): 81 | """Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features.""" 82 | 83 | def __init__( 84 | self, 85 | input_shapes: Mapping[str, Sequence[int]], 86 | config: DoubleNormLogConfig, 87 | ): 88 | super().__init__() 89 | 90 | _before_concat_layers = [ 91 | InputNonFinite(), 92 | Log1pAbs(), 93 | ] 94 | if config.batch_norm_config: 95 | _before_concat_layers.append( 96 | BatchNorm(input_shapes["continuous"][-1], config.batch_norm_config) 97 | ) 98 | _before_concat_layers.append( 99 | Clamp(min_value=-config.clip_magnitude, max_value=config.clip_magnitude) 100 | ) 101 | self._before_concat_layers = 
torch.nn.Sequential(*_before_concat_layers) 102 | 103 | self.layer_norm = None 104 | if config.layer_norm_config: 105 | last_dim = input_shapes["continuous"][-1] + input_shapes["binary"][-1] 106 | self.layer_norm = LayerNorm(last_dim, config.layer_norm_config) 107 | 108 | def forward( 109 | self, continuous_features: torch.Tensor, binary_features: torch.Tensor 110 | ) -> torch.Tensor: 111 | x = self._before_concat_layers(continuous_features) 112 | x = torch.cat([x, binary_features], dim=1) 113 | if self.layer_norm: 114 | return self.layer_norm(x) 115 | return x 116 | 117 | 118 | def build_features_preprocessor( 119 | config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]] 120 | ): 121 | """Trivial right now, but we will change in the future.""" 122 | return DoubleNormLog(input_shapes, config.double_norm_log_config) 123 | -------------------------------------------------------------------------------- /projects/home/recap/data/tfe_parsing.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import json 3 | 4 | from tml.projects.home.recap.data import config as recap_data_config 5 | 6 | from absl import logging 7 | import tensorflow as tf 8 | 9 | 10 | DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""} 11 | DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string} 12 | 13 | 14 | def create_tf_example_schema( 15 | data_config: recap_data_config.SegDenseSchema, 16 | segdense_schema, 17 | ): 18 | """Generate schema for deseralizing tf.Example. 19 | 20 | Args: 21 | segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length). 22 | labels: List of strings denoting labels. 23 | 24 | Returns: 25 | A dictionary schema suitable for deserializing tf.Example. 
26 | """ 27 | segdense_config = data_config.seg_dense_schema 28 | labels = list(data_config.tasks.keys()) 29 | used_features = ( 30 | segdense_config.features + list(segdense_config.renamed_features.values()) + labels 31 | ) 32 | logging.info(used_features) 33 | 34 | tfe_schema = {} 35 | for entry in segdense_schema: 36 | feature_name = entry["feature_name"] 37 | 38 | if feature_name in used_features: 39 | length = entry["length"] 40 | dtype = entry["dtype"] 41 | 42 | if feature_name in labels: 43 | logging.info(f"Label: feature name is {feature_name} type is {dtype}") 44 | tfe_schema[feature_name] = tf.io.FixedLenFeature( 45 | length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype] 46 | ) 47 | elif length == -1: 48 | tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype]) 49 | else: 50 | tfe_schema[feature_name] = tf.io.FixedLenFeature( 51 | length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length 52 | ) 53 | for feature_name in used_features: 54 | if feature_name not in tfe_schema: 55 | raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.") 56 | return tfe_schema 57 | 58 | 59 | @functools.lru_cache(1) 60 | def make_mantissa_mask(mask_length: int) -> tf.Tensor: 61 | """For experimentating with emulating bfloat16 or less precise types.""" 62 | return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32) 63 | 64 | 65 | def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor: 66 | """For experimentating with emulating bfloat16 or less precise types.""" 67 | mask: tf.Tensor = make_mantissa_mask(mask_length) 68 | return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype) 69 | 70 | 71 | def parse_tf_example( 72 | serialized_example, 73 | tfe_schema, 74 | seg_dense_schema_config, 75 | ): 76 | """Parse serialized tf.Example into dict of tensors. 77 | 78 | Args: 79 | serialized_example: Serialized tf.Example to be parsed. 
80 | tfe_schema: Dictionary schema suitable for deserializing tf.Example. 81 | seg_dense_schema_config: Seg-dense schema config; its renamed_features (new name -> parsed name) are applied to the parsed tensors, followed by mask_mantissa_features when present. 82 | Returns: 83 | Dictionary of tensors to be used as model input. 84 | """ 85 | inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema) 86 | 87 | for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items(): 88 | inputs[new_feature_name] = inputs.pop(old_feature_name) 89 | 90 | # This should not actually be used except for experimentation with low precision floats. 91 | if "mask_mantissa_features" in seg_dense_schema_config: 92 | for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items(): 93 | inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length) 94 | 95 | # DANGER DANGER: This default seems really scary, and it's only here because it has to be visible 96 | # at TF level. 97 | # We should not return empty tensors if we dont use embeddings. 98 | # Otherwise, it breaks numpy->pt conversion 99 | renamed_keys = list(seg_dense_schema_config.renamed_features.keys()) # NOTE(review): the rename loop above already inserted every one of these keys (inputs.pop raises KeyError otherwise), so the zeros default below looks unreachable - confirm. 100 | for renamed_key in renamed_keys: 101 | if "embedding" in renamed_key and (renamed_key not in inputs): 102 | inputs[renamed_key] = tf.zeros([], tf.float32) 103 | 104 | logging.info(f"parsed example and inputs are {inputs}") 105 | return inputs 106 | 107 | 108 | def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig): 109 | """Placeholder for seg dense. 110 | 111 | In the future, when we use more seg dense variations, we can change this.
112 | """ 113 | with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f: # the schema JSON keeps its feature entries under a top-level "schema" key 114 | seg_dense_schema = json.load(f)["schema"] 115 | 116 | tf_example_schema = create_tf_example_schema( 117 | data_config, 118 | seg_dense_schema, 119 | ) 120 | 121 | logging.info("***** TF Example Schema *****") 122 | logging.info(tf_example_schema) 123 | 124 | parse = functools.partial( 125 | parse_tf_example, 126 | tfe_schema=tf_example_schema, 127 | seg_dense_schema_config=data_config.seg_dense_schema, 128 | ) 129 | return parse # callable: serialized tf.Examples -> dict of feature tensors (see parse_tf_example) 130 | -------------------------------------------------------------------------------- /reader/dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset to be subclassed that can work with or without distributed reading. 2 | 3 | - Override `pa_to_batch` for dataset specific imputation, negative sampling, or coercion to Batch. 4 | - Readers can be colocated or off trainer machines. 5 | 6 | """ 7 | import abc 8 | import functools 9 | import random 10 | from typing import Optional # NOTE(review): Optional, functools and LocalFileSystem appear unused in this module. 11 | 12 | from fsspec.implementations.local import LocalFileSystem 13 | import pyarrow.dataset as pads 14 | import pyarrow as pa 15 | import pyarrow.parquet 16 | import pyarrow.flight 17 | from pyarrow.ipc import IpcWriteOptions 18 | import torch 19 | 20 | from tml.common.batch import DataclassBatch 21 | from tml.machines import environment as env 22 | import tml.reader.utils as reader_utils 23 | from tml.common.filesystem import infer_fs 24 | from tml.ml_logging.torch_logging import logging 25 | 26 | 27 | class _Reader(pa.flight.FlightServerBase): 28 | """Distributed reader flight server wrapping a dataset.""" 29 | 30 | def __init__(self, location: str, ds: "Dataset"): 31 | super().__init__(location=location) 32 | self._location = location 33 | self._ds = ds 34 | 35 | def do_get(self, _, __): 36 | # NB: An updated schema (to account for column selection) has to be given to the stream, so it is read from the first batch of a separate to_batches() pass.
37 | schema = next(iter(self._ds.to_batches())).schema 38 | batches = self._ds.to_batches() 39 | return pa.flight.RecordBatchStream( 40 | data_source=pa.RecordBatchReader.from_batches( 41 | schema=schema, 42 | batches=batches, 43 | ), 44 | options=IpcWriteOptions(use_threads=True), 45 | ) 46 | 47 | 48 | class Dataset(torch.utils.data.IterableDataset): # Base class: subclasses implement pa_to_batch to turn pyarrow RecordBatches into DataclassBatches. 49 | LOCATION = "grpc://0.0.0.0:2222" 50 | 51 | def __init__(self, file_pattern: str, **dataset_kwargs) -> None: 52 | """Specify batch size and column to select for. 53 | 54 | Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset. 55 | """ 56 | self._file_pattern = file_pattern 57 | self._fs = infer_fs(self._file_pattern) 58 | self._dataset_kwargs = dataset_kwargs 59 | logging.info(f"Using dataset_kwargs: {self._dataset_kwargs}") 60 | self._files = self._fs.glob(self._file_pattern) 61 | assert len(self._files) > 0, f"No files found at {self._file_pattern}" 62 | logging.info(f"Found {len(self._files)} files: {', '.join(self._files[:4])}, ...") 63 | self._schema = pa.parquet.read_schema(self._files[0], filesystem=self._fs) 64 | self._validate_columns() 65 | 66 | def _validate_columns(self): 67 | columns = set(self._dataset_kwargs.get("columns", [])) 68 | wrong_columns = columns - set(self._schema.names) 69 | if wrong_columns: 70 | raise ValueError(f"Specified columns {list(wrong_columns)} not in schema.") 71 | 72 | def serve(self): # Start a Flight server at LOCATION that streams this dataset's batches. 73 | self.reader = _Reader(location=self.LOCATION, ds=self) 74 | self.reader.serve() 75 | 76 | def _create_dataset(self): # Open a single file, chosen uniformly at random on every call. 77 | return pads.dataset( 78 | source=random.choice(self._files), 79 | format="parquet", 80 | filesystem=self._fs, 81 | exclude_invalid_files=False, 82 | ) 83 | 84 | def to_batches(self): 85 | """This allows the init to control reading settings. 86 | 87 | Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.
88 | 89 | Perform `drop_remainder` behavior to fix the batch size. 90 | This does not shift our data distribution because of volume and file-level shuffling on every repeat. 91 | """ 92 | batch_size = self._dataset_kwargs["batch_size"] 93 | while True: 94 | ds = self._create_dataset() 95 | for batch in ds.to_batches(**self._dataset_kwargs): 96 | if batch.num_rows < batch_size: 97 | logging.info(f"Dropping remainder ({batch.num_rows}/{batch_size})") 98 | break # NOTE(review): pyarrow's batch_size is only a maximum, so short batches may also appear at row-group/fragment boundaries; breaking here then skips the rest of the file, not just the tail. 99 | yield batch 100 | 101 | @abc.abstractmethod 102 | def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch: 103 | raise NotImplementedError 104 | 105 | def dataloader(self, remote: bool = False): 106 | if not remote: 107 | return map(self.pa_to_batch, self.to_batches()) 108 | readers = get_readers(2) 109 | return map(self.pa_to_batch, reader_utils.roundrobin(*readers)) 110 | 111 | 112 | GRPC_OPTIONS = [ # gRPC channel-argument keys are the "grpc.*" strings; the GRPC_ARG_* macro names are not recognized as keys. 113 | ("grpc.keepalive_time_ms", 60000), 114 | ("grpc.min_reconnect_backoff_ms", 2000), 115 | ("grpc.max_metadata_size", 1024 * 1024 * 1024), 116 | ] 117 | 118 | 119 | def get_readers(num_readers_per_worker: int): # NOTE(review): num_readers_per_worker is currently unused; one reader is opened per flight server address. 120 | addresses = env.get_flight_server_addresses() 121 | 122 | readers = [] 123 | for worker in addresses: 124 | logging.info(f"Attempting connection to reader {worker}.") 125 | client = pa.flight.connect(worker, generic_options=GRPC_OPTIONS) 126 | client.wait_for_available(60) 127 | reader = client.do_get(None).to_reader() 128 | logging.info(f"Connected reader to {worker}.") 129 | readers.append(reader) 130 | return readers 131 | -------------------------------------------------------------------------------- /core/test_metrics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from tml.core import metrics as core_metrics 4 | from tml.core.metric_mixin import MetricMixin, prepend_transform 5 | 6 | import torch 7 | from torchmetrics import MaxMetric, MetricCollection, SumMetric 8 | 9 | 10 | @dataclass
11 | class MockStratifierConfig: 12 | name: str 13 | index: int 14 | value: int 15 | 16 | 17 | class Count(MetricMixin, SumMetric): # counts update() calls: each call contributes a value of 1 18 | def transform(self, outputs): 19 | return {"value": 1} 20 | 21 | 22 | Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]}) 23 | 24 | 25 | def test_count_metric(): 26 | num_examples = 123 27 | examples = [ 28 | {"stuff": 0}, 29 | ] * num_examples 30 | 31 | metric = Count() 32 | for outputs in examples: 33 | metric.update(outputs) 34 | 35 | assert metric.compute().item() == num_examples 36 | 37 | 38 | def test_collections(): 39 | max_metric = Max() 40 | count_metric = Count() 41 | metric = MetricCollection([max_metric, count_metric]) 42 | 43 | examples = [{"value": idx} for idx in range(123)] 44 | for outputs in examples: 45 | metric.update(outputs) 46 | 47 | assert metric.compute() == { 48 | max_metric.__class__.__name__: len(examples) - 1, 49 | count_metric.__class__.__name__: len(examples), 50 | } 51 | 52 | 53 | def test_task_dependent_ctr(): 54 | num_examples = 144 55 | batch_size = 1024 56 | outputs = [ 57 | { 58 | "stuff": 0, 59 | "labels": torch.arange(0, 6).repeat(batch_size, 1), 60 | } 61 | for idx in range(num_examples) 62 | ] 63 | 64 | for task_idx in range(5): # every label row is [0, ..., 5], so the Ctr of column task_idx is task_idx 65 | metric = core_metrics.Ctr(task_idx=task_idx) 66 | for output in outputs: 67 | metric.update(output) 68 | assert metric.compute().item() == task_idx 69 | 70 | 71 | def test_stratified_ctr(): 72 | outputs = [ 73 | { 74 | "stuff": 0, 75 | # [bsz, tasks] 76 | "labels": torch.tensor( 77 | [ 78 | [0, 1, 2, 3], 79 | [1, 2, 3, 4], 80 | [2, 3, 4, 0], 81 | ] 82 | ), 83 | "stratifiers": { 84 | # [bsz] 85 | "level": torch.tensor( 86 | [9, 0, 9], 87 | ), 88 | }, 89 | } 90 | ] 91 | 92 | stratifier = MockStratifierConfig(name="level", index=2, value=9) 93 | for task_idx in range(5): # NOTE(review): task_idx is unused; every iteration checks task 1. 94 | metric = core_metrics.Ctr(task_idx=1, stratifier=stratifier) 95 | for output in outputs: 96 | metric.update(output) 97 | # From the dataset of: 98 | # [ 99 | # [0, 1,
2, 3], 100 | # [1, 2, 3, 4], 101 | # [2, 3, 4, 0], 102 | # ] 103 | # we pick out 104 | # [ 105 | # [0, 1, 2, 3], 106 | # [2, 3, 4, 0], 107 | # ] 108 | # and with Ctr task_idx=1, we pick out 109 | # [ 110 | # [1,], 111 | # [3,], 112 | # ] 113 | assert metric.compute().item() == (1 + 3) / 2 114 | 115 | 116 | def test_auc(): 117 | num_samples = 10000 118 | metric = core_metrics.Auc(num_samples) 119 | target = torch.tensor([0, 0, 1, 1, 1]) 120 | preds_correct = torch.tensor([-1.0, -1.0, 1.0, 1.0, 1.0]) 121 | outputs_correct = {"logits": preds_correct, "labels": target} 122 | preds_bad = torch.tensor([1.0, 1.0, -1.0, -1.0, -1.0]) 123 | outputs_bad = {"logits": preds_bad, "labels": target} 124 | 125 | metric.update(outputs_correct) 126 | assert metric.compute().item() == 1.0 # every positive outscores every negative, so every sampled pair is correct 127 | 128 | metric.reset() 129 | metric.update(outputs_bad) 130 | assert metric.compute().item() == 0.0 131 | 132 | 133 | def test_pos_rank(): 134 | metric = core_metrics.PosRanks() 135 | target = torch.tensor([0, 0, 1, 1, 1]) 136 | preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5]) 137 | outputs_correct = {"logits": preds_correct, "labels": target} 138 | preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5]) 139 | outputs_bad = {"logits": preds_bad, "labels": target} 140 | 141 | metric.update(outputs_correct) 142 | assert metric.compute().item() == 2.0 # positives rank 1, 2, 3 143 | 144 | metric.reset() 145 | metric.update(outputs_bad) 146 | assert metric.compute().item() == 4.0 # positives rank 3, 4, 5 147 | 148 | 149 | def test_reciprocal_rank(): 150 | metric = core_metrics.ReciprocalRank() 151 | target = torch.tensor([0, 0, 1, 1, 1]) 152 | preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5]) 153 | outputs_correct = {"logits": preds_correct, "labels": target} 154 | preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5]) 155 | outputs_bad = {"logits": preds_bad, "labels": target} 156 | 157 | metric.update(outputs_correct) 158 | assert abs(metric.compute().item() - 0.6111) < 0.001 # mean of 1/1, 1/2, 1/3 159 | 160 | metric.reset() 161 | metric.update(outputs_bad)
162 | assert abs(metric.compute().item() - 0.2611) < 0.001 # positives rank 3, 4, 5: mean of 1/3, 1/4, 1/5 163 | 164 | 165 | def test_hit_k(): 166 | hit1_metric = core_metrics.HitAtK(1) 167 | target = torch.tensor([0, 0, 1, 1, 1]) 168 | preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5]) 169 | outputs_correct = {"logits": preds_correct, "labels": target} 170 | preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5]) 171 | outputs_bad = {"logits": preds_bad, "labels": target} 172 | 173 | hit1_metric.update(outputs_correct) 174 | assert abs(hit1_metric.compute().item() - 0.3333) < 0.0001 # positives rank 1, 3, 4 175 | 176 | hit1_metric.reset() 177 | hit1_metric.update(outputs_bad) 178 | 179 | assert hit1_metric.compute().item() == 0 # positives rank 3, 4, 5 180 | 181 | hit3_metric = core_metrics.HitAtK(3) 182 | hit3_metric.update(outputs_correct) 183 | assert abs(hit3_metric.compute().item() - 0.66666) < 0.0001 184 | 185 | hit3_metric.reset() 186 | hit3_metric.update(outputs_bad) 187 | assert abs(hit3_metric.compute().item() - 0.3333) < 0.0001 188 | -------------------------------------------------------------------------------- /core/metrics.py: -------------------------------------------------------------------------------- 1 | """Common metrics that also support multi task. 2 | 3 | We assume multi task models will output [task_idx, ...]
predictions 4 | 5 | """ 6 | from typing import Any, Dict 7 | 8 | from tml.core.metric_mixin import MetricMixin, StratifyMixin, TaskMixin 9 | 10 | import torch 11 | import torchmetrics as tm 12 | 13 | 14 | def probs_and_labels( # build (preds, target) kwargs for torchmetrics, selecting column task_idx when task_idx >= 0 15 | outputs: Dict[str, torch.Tensor], 16 | task_idx: int, 17 | ) -> Dict[str, torch.Tensor]: 18 | preds = outputs["probabilities"] 19 | target = outputs["labels"] 20 | if task_idx >= 0: 21 | preds = preds[:, task_idx] 22 | target = target[:, task_idx] 23 | return { 24 | "preds": preds, 25 | "target": target.int(), # integer targets (presumably what the torchmetrics classification metrics expect) 26 | } 27 | 28 | 29 | class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric): # sum of labels (column task_idx when task_idx >= 0) 30 | def transform(self, outputs): 31 | outputs = self.maybe_apply_stratification(outputs, ["labels"]) 32 | value = outputs["labels"] 33 | if self._task_idx >= 0: 34 | value = value[:, self._task_idx] 35 | return {"value": value} 36 | 37 | 38 | class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): # mean label (column task_idx when task_idx >= 0) 39 | def transform(self, outputs): 40 | outputs = self.maybe_apply_stratification(outputs, ["labels"]) 41 | value = outputs["labels"] 42 | if self._task_idx >= 0: 43 | value = value[:, self._task_idx] 44 | return {"value": value} 45 | 46 | 47 | class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): # mean predicted probability (column task_idx when task_idx >= 0) 48 | def transform(self, outputs): 49 | outputs = self.maybe_apply_stratification(outputs, ["probabilities"]) 50 | value = outputs["probabilities"] 51 | if self._task_idx >= 0: 52 | value = value[:, self._task_idx] 53 | return {"value": value} 54 | 55 | 56 | class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision): 57 | def transform(self, outputs): 58 | outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) 59 | return probs_and_labels(outputs, self._task_idx) 60 | 61 | 62 | class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall): 63 | def transform(self, outputs): 64 | outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) 65 | return probs_and_labels(outputs,
self._task_idx) 66 | 67 | 68 | class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC): 69 | def transform(self, outputs): 70 | outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) 71 | return probs_and_labels(outputs, self._task_idx) 72 | 73 | 74 | class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): 75 | """Monte-Carlo AUC estimate: draws num_samples positive and num_samples negative scores (with replacement) and averages pos > neg. 76 | Based on: 77 | https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420 78 | """ 79 | 80 | def __init__(self, num_samples, **kwargs): 81 | super().__init__(**kwargs) 82 | self.num_samples = num_samples 83 | 84 | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: 85 | scores, labels = outputs["logits"], outputs["labels"] 86 | pos_scores = scores[labels == 1] 87 | neg_scores = scores[labels == 0] 88 | result = { # NOTE(review): torch.randint raises if the batch has no positives or no negatives. 89 | "value": pos_scores[torch.randint(len(pos_scores), (self.num_samples,))] 90 | > neg_scores[torch.randint(len(neg_scores), (self.num_samples,))] 91 | } 92 | return result 93 | 94 | 95 | class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): 96 | """ 97 | The 1-based ranks of all positives when scores are sorted in descending order 98 | Based on: 99 | https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73 100 | """ 101 | 102 | def __init__(self, **kwargs): 103 | super().__init__(**kwargs) 104 | 105 | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: 106 | scores, labels = outputs["logits"], outputs["labels"] 107 | _, sorted_indices = scores.sort(descending=True) 108 | pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1 # all ranks start from 1 109 | result = {"value": pos_ranks} 110 | return result 111 | 112 | 113 | class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): 114 | """ 115 | The reciprocal of the ranks of all positives 116 | Based on: 117 |
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74 118 | """ 119 | 120 | def __init__(self, **kwargs): 121 | super().__init__(**kwargs) 122 | 123 | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: 124 | scores, labels = outputs["logits"], outputs["labels"] 125 | _, sorted_indices = scores.sort(descending=True) 126 | pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1 # all ranks start from 1 127 | result = {"value": torch.div(torch.ones_like(pos_ranks), pos_ranks)} # torch.div performs true division by default, so integer ranks yield float reciprocals 128 | return result 129 | 130 | 131 | class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): 132 | """ 133 | The fraction of positives ranked within the top K of the descending-sorted scores (other positives included) 134 | Note that this is recall@k over the positives rather than precision@k 135 | Based on: 136 | https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75 137 | """ 138 | 139 | def __init__(self, k: int, **kwargs): 140 | super().__init__(**kwargs) 141 | self.k = k 142 | 143 | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: 144 | scores, labels = outputs["logits"], outputs["labels"] 145 | _, sorted_indices = scores.sort(descending=True) 146 | pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1 # all ranks start from 1 147 | result = {"value": (pos_ranks <= self.k).float()} 148 | return result 149 | -------------------------------------------------------------------------------- /projects/twhin/data/edges.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass 3 | from typing import Dict, List, Tuple 4 | 5 | from tml.common.batch import DataclassBatch 6 | from tml.reader.dataset import Dataset 7 | from tml.projects.twhin.models.config import Relation 8 | 9 | import numpy as np 10 | import pyarrow as pa 11 | import
pyarrow.compute as pc 12 | import torch 13 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor 14 | 15 | 16 | @dataclass 17 | class EdgeBatch(DataclassBatch): 18 | nodes: KeyedJaggedTensor 19 | labels: torch.Tensor 20 | rels: torch.Tensor 21 | weights: torch.Tensor 22 | 23 | 24 | class EdgesDataset(Dataset): 25 | rng = np.random.default_rng() # NOTE(review): appears unused in this module 26 | 27 | def __init__( 28 | self, 29 | file_pattern: str, 30 | table_sizes: Dict[str, int], 31 | relations: List[Relation], 32 | lhs_column_name: str = "lhs", 33 | rhs_column_name: str = "rhs", 34 | rel_column_name: str = "rel", 35 | **dataset_kwargs 36 | ): 37 | self.batch_size = dataset_kwargs["batch_size"] # _to_kjt reshapes with this, so every batch must be full (the base Dataset drops short batches) 38 | 39 | self.table_sizes = table_sizes 40 | self.num_tables = len(table_sizes) 41 | self.table_names = list(table_sizes.keys()) 42 | 43 | self.relations = relations 44 | self.relations_t = torch.tensor( # [num_relations, 2]: (lhs table index, rhs table index) for each relation 45 | [ 46 | [self.table_names.index(relation.lhs), self.table_names.index(relation.rhs)] 47 | for relation in relations 48 | ] 49 | ) 50 | 51 | self.lhs_column_name = lhs_column_name 52 | self.rhs_column_name = rhs_column_name 53 | self.rel_column_name = rel_column_name 54 | self.label_column_name = "label" 55 | 56 | super().__init__(file_pattern=file_pattern, **dataset_kwargs) 57 | 58 | def pa_to_batch(self, batch: pa.RecordBatch): 59 | lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy()) 60 | rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy()) 61 | rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy()) 62 | label = torch.from_numpy(batch.column(self.label_column_name).to_numpy()) 63 | 64 | nodes = self._to_kjt(lhs, rhs, rel) 65 | return EdgeBatch( 66 | nodes=nodes, 67 | rels=rel, 68 | labels=label, 69 | weights=torch.ones(batch.num_rows), 70 | ) 71 | 72 | def _to_kjt( 73 | self, lhs: torch.Tensor, rhs: torch.Tensor, rel: torch.Tensor 74 | ) -> KeyedJaggedTensor: 75 | 76 | """Process edges that contain lhs index, rhs index,
relation index. 77 | Example: 78 | 79 | ``` 80 | tables = ["f0", "f1", "f2"] 81 | relations = [["f0", "f1"], ["f1", "f2"], ["f1", "f0"], ["f2", "f1"], ["f0", "f2"]] 82 | self.relations_t = torch.Tensor([[0, 1], [1, 2], [1, 0], [2, 1], [0, 2]]) 83 | lhs = [1, 6, 3, 1, 8] 84 | rhs = [6, 3, 4, 4, 9] 85 | rel = [0, 2, 1, 3, 4] 86 | 87 | This corresponds to the following "edges": 88 | edges = [ 89 | {"lhs": 1, "rhs": 6, "relation": ["f0", "f1"]}, 90 | {"lhs": 6, "rhs": 3, "relation": ["f1", "f0"]}, 91 | {"lhs": 3, "rhs": 4, "relation": ["f1", "f2"]}, 92 | {"lhs": 1, "rhs": 4, "relation": ["f2", "f1"]}, 93 | {"lhs": 8, "rhs": 9, "relation": ["f0", "f2"]}, 94 | ] 95 | ``` 96 | 97 | Returns a KeyedJaggedTensor used to look up all embeddings. 98 | 99 | Note: We treat the lhs and rhs as though they're separate lookups: `len(lengths) == 2 * bsz * len(tables)`. 100 | This differs from the DLRM pattern where we have `len(lengths) = bsz * len(tables)`. 101 | 102 | For the example above: 103 | ``` 104 | lookups = tensor([ 105 | [0., 1.], 106 | [1., 6.], 107 | [1., 6.], 108 | [0., 3.], 109 | [1., 3.], 110 | [2., 4.], 111 | [2., 1.], 112 | [1., 4.], 113 | [0., 8.], 114 | [2., 9.]
115 | ]) 116 | 117 | kjt = KeyedJaggedTensor( 118 | keys=["f0", "f1", "f2"], 119 | values=[ 120 | 1, 3, 8, # f0 121 | 6, 6, 3, 4, # f1 122 | 4, 1, 9 # f2 123 | ], 124 | lengths=[ 125 | 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, # f0 126 | 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, # f1 127 | 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, # f2 128 | ]) 129 | ``` 130 | 131 | Note: 132 | - values = [values for f0] + [values for f1] + [values for f2] 133 | - lengths are always 0 or 1, and sum(lengths) = len(values) = 2 * bsz 134 | """ 135 | lookups = torch.concat((lhs[:, None], self.relations_t[rel], rhs[:, None]), dim=1) # [bsz, 4]: lhs id, lhs table, rhs table, rhs id 136 | index = torch.LongTensor([1, 0, 2, 3]) # reorder to (lhs table, lhs id, rhs table, rhs id) so each row splits into two (table, id) lookups 137 | lookups = lookups[:, index].reshape(2 * self.batch_size, 2) 138 | 139 | # values is just the row indices into each table, ordered by the table indices 140 | _, indices = torch.sort(lookups[:, 0], dim=0, stable=True) 141 | values = lookups[indices][:, 1].int() 142 | 143 | # lengths[table_idx * 2 * batch_size + i] == whether the ith of the 2 * batch_size lookups is for the table with index table_idx 144 | lengths = torch.arange(self.num_tables)[:, None].eq(lookups[:, 0]) 145 | lengths = lengths.reshape(-1).int() 146 | 147 | return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths) 148 | 149 | def to_batches(self): 150 | ds = super().to_batches() 151 | # Every yielded edge is a positive example (label 1); labels are sized from each batch. 152 | 153 | names = [ 154 | self.lhs_column_name, 155 | self.rhs_column_name, 156 | self.rel_column_name, 157 | self.label_column_name, 158 | ] 159 | for batch in ds: 160 | # Pass along positive edges 161 | lhs = batch.column(self.lhs_column_name) 162 | rhs = batch.column(self.rhs_column_name) 163 | rel = batch.column(self.rel_column_name) 164 | label = pa.array(np.ones(batch.num_rows, dtype=np.int64)) 165 | 166 | yield pa.RecordBatch.from_arrays( 167 | arrays=[lhs, rhs, rel, label], 168 | names=names, 169 | ) 170 | -------------------------------------------------------------------------------- /projects/twhin/models/models.py:
-------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import math 3 | 4 | from tml.projects.twhin.data.edges import EdgeBatch 5 | from tml.projects.twhin.models.config import TwhinModelConfig 6 | from tml.projects.twhin.data.config import TwhinDataConfig 7 | from tml.common.modules.embedding.embedding import LargeEmbeddings 8 | from tml.optimizers.optimizer import get_optimizer_class 9 | from tml.optimizers.config import get_optimizer_algorithm_config 10 | 11 | import torch 12 | from torch import nn 13 | from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward 14 | 15 | 16 | class TwhinModel(nn.Module): 17 | def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig): 18 | super().__init__() 19 | self.batch_size = data_config.per_replica_batch_size 20 | self.table_names = [table.name for table in model_config.embeddings.tables] 21 | self.large_embeddings = LargeEmbeddings(model_config.embeddings) 22 | self.embedding_dim = model_config.embeddings.tables[0].embedding_dim 23 | self.num_tables = len(model_config.embeddings.tables) 24 | self.in_batch_negatives = data_config.in_batch_negatives 25 | self.global_negatives = data_config.global_negatives # NOTE(review): not used by forward 26 | self.num_relations = len(model_config.relations) 27 | 28 | # one learnable translation vector per relation, added to the rhs embedding in forward 29 | self.all_trans_embs = torch.nn.parameter.Parameter( 30 | torch.nn.init.uniform_(torch.empty(self.num_relations, self.embedding_dim)) 31 | ) 32 | 33 | def forward(self, batch: EdgeBatch): 34 | 35 | # B x D; index the Parameter itself (not .data) so gradients reach all_trans_embs 36 | trans_embs = self.all_trans_embs[batch.rels] 37 | 38 | # KeyedTensor 39 | outs = self.large_embeddings(batch.nodes) 40 | 41 | # 2B x TD 42 | x = outs.values() 43 | 44 | # 2B x T x D 45 | x = x.reshape(2 * self.batch_size, -1, self.embedding_dim) 46 | 47 | # 2B x D 48 | x = torch.sum(x, 1) 49 | 50 | # B x 2 x D 51 | x = x.reshape(self.batch_size, 2, self.embedding_dim) 52 | 53 | # translated 54 | translated = x[:,
1, :] + trans_embs 55 | 56 | negs = [] 57 | if self.in_batch_negatives: 58 | # construct dot products for negatives via matmul 59 | for relation in range(self.num_relations): 60 | rel_mask = batch.rels == relation 61 | rel_count = rel_mask.sum() 62 | 63 | if not rel_count: 64 | continue 65 | 66 | # R x D 67 | lhs_matrix = x[rel_mask, 0, :] 68 | rhs_matrix = x[rel_mask, 1, :] 69 | 70 | lhs_perm = torch.randperm(lhs_matrix.shape[0]) 71 | # repeat until we have enough negatives 72 | lhs_perm = lhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count)) 73 | lhs_indices = lhs_perm[: self.in_batch_negatives] 74 | sampled_lhs = lhs_matrix[lhs_indices] 75 | 76 | rhs_perm = torch.randperm(rhs_matrix.shape[0]) 77 | # repeat until we have enough negatives 78 | rhs_perm = rhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count)) 79 | rhs_indices = rhs_perm[: self.in_batch_negatives] 80 | sampled_rhs = rhs_matrix[rhs_indices] 81 | 82 | # R x S scores (S = in_batch_negatives), flattened to R * S 83 | negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t())) 84 | negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t())) # NOTE(review): negatives are raw dot products without the relation translation used for positives - confirm this is intended. 85 | 86 | negs.append(negs_lhs) 87 | negs.append(negs_rhs) 88 | 89 | # dot product for positives: lhs . (rhs + relation translation) 90 | x = (x[:, 0, :] * translated).sum(-1) 91 | 92 | # concat positives (first B entries) and negatives 93 | x = torch.cat([x, *negs]) 94 | return { 95 | "logits": x, 96 | "probabilities": torch.sigmoid(x), 97 | } 98 | 99 | 100 | def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig): # attach each table's configured optimizer to its embedding-bag parameters, applied during backward by torchrec 101 | for table in model_config.embeddings.tables: 102 | optimizer_class = get_optimizer_class(table.optimizer) 103 | optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict() 104 | params = [ 105 | param 106 | for name, param in model.large_embeddings.ebc.named_parameters() 107 | if (name.startswith(f"embedding_bags.{table.name}")) 108 | ] 109 | apply_optimizer_in_backward( 110 | optimizer_class=optimizer_class, 111 | params=params, 112 | optimizer_kwargs=optimizer_kwargs, 113 | )
114 | 115 | return model 116 | 117 | 118 | class TwhinModelAndLoss(torch.nn.Module): 119 | def __init__( 120 | self, 121 | model, 122 | loss_fn: Callable, 123 | data_config: TwhinDataConfig, 124 | device: torch.device, 125 | ) -> None: 126 | """ 127 | Args: 128 | model: torch module to wrap; its "logits" output holds the positives followed by the in-batch negatives. 129 | loss_fn: Function for calculating loss, called as loss_fn(logits, labels, weights). 130 | """ 131 | super().__init__() 132 | self.model = model 133 | self.loss_fn = loss_fn 134 | self.batch_size = data_config.per_replica_batch_size 135 | self.in_batch_negatives = data_config.in_batch_negatives 136 | self.device = device 137 | 138 | def forward(self, batch: EdgeBatch): 139 | """Runs model forward and calculates loss according to given loss_fn. 140 | 141 | NOTE: The input signature here needs to be a Pipelineable object for 142 | prefetching purposes during training using torchrec's pipeline. However 143 | the underlying model signature needs to be exportable to onnx, requiring 144 | generic python types. see https://pytorch.org/docs/stable/onnx.html#types. 145 | 146 | """ 147 | outputs = self.model(batch) 148 | logits = outputs["logits"] 149 | 150 | num_negatives = 2 * self.batch_size * self.in_batch_negatives 151 | num_positives = self.batch_size 152 | 153 | neg_weight = float(num_positives) / num_negatives if num_negatives else 0.0 # negatives together weigh as much as the positives 154 | 155 | labels = torch.cat([batch.labels.float(), torch.zeros(num_negatives, device=self.device)]) # in-batch negatives are label 0 156 | 157 | weights = torch.cat( 158 | [batch.weights.float(), torch.full((num_negatives,), neg_weight, device=self.device)] 159 | ) 160 | 161 | losses = self.loss_fn(logits, labels, weights) 162 | 163 | outputs.update( 164 | { 165 | "loss": losses, 166 | "labels": labels, 167 | "weights": weights, 168 | } 169 | ) 170 | 171 | # Allow multiple losses.
172 | return losses, outputs 173 | -------------------------------------------------------------------------------- /metrics/auroc.py: -------------------------------------------------------------------------------- 1 | """ 2 | AUROC metrics. 3 | """ 4 | from typing import Union 5 | 6 | from tml.ml_logging.torch_logging import logging 7 | 8 | import torch 9 | import torchmetrics 10 | from torchmetrics.utilities.data import dim_zero_cat 11 | 12 | 13 | def _compute_helper( 14 | predictions: torch.Tensor, 15 | target: torch.Tensor, 16 | weights: torch.Tensor, 17 | max_positive_negative_weighted_sum: torch.Tensor, 18 | min_positive_negative_weighted_sum: torch.Tensor, 19 | equal_predictions_as_incorrect: bool, 20 | ) -> torch.Tensor: 21 | """ 22 | Compute AUROC. 23 | Args: 24 | predictions: The predictions probabilities. 25 | target: The target. 26 | weights: The sample weights to assign to each sample in the batch. 27 | max_positive_negative_weighted_sum: The larger of the weighted positive sum and the weighted negative sum. 28 | min_positive_negative_weighted_sum: The smaller of those two sums; dividing by both normalizes by (positive sum * negative sum). 29 | equal_predictions_as_incorrect: For positive & negative labels having identical scores, 30 | we assume that they are correct prediction (i.e weight = 1) when this is False. Otherwise, 31 | we assume that they are incorrect prediction (i.e weight = 0). 32 | """ 33 | dim = 0 34 | 35 | # Sort predictions based on key (score, true_label). The order is ascending for score. 36 | # For true_label, order is descending if equal_predictions_as_incorrect is True (positives sort before tied negatives, so ties do not count); 37 | # otherwise it is ascending (tied negatives sort before positives, so ties count as correct).
38 | target_order = torch.argsort(target, dim=dim, descending=equal_predictions_as_incorrect) 39 | score_order = torch.sort(torch.gather(predictions, dim, target_order), stable=True, dim=dim)[1] # stable, so tied scores keep the label order chosen above 40 | score_order = torch.gather(target_order, dim, score_order) # map back to indices into the original tensors 41 | sorted_target = torch.gather(target, dim, score_order) 42 | sorted_weights = torch.gather(weights, dim, score_order) 43 | 44 | negatives_from_left = torch.cumsum((1.0 - sorted_target) * sorted_weights, 0) # weighted count of negatives up to and including each position 45 | 46 | numerator = torch.sum( 47 | sorted_weights * (sorted_target * negatives_from_left / max_positive_negative_weighted_sum) 48 | ) 49 | 50 | return numerator / min_positive_negative_weighted_sum # overall division by (positive sum * negative sum) 51 | 52 | 53 | class AUROCWithMWU(torchmetrics.Metric): 54 | """ 55 | AUROC using Mann-Whitney U-test. 56 | See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve. 57 | 58 | This AUROC implementation is well suited to (non-zero) low-CTR. In particular it will return 59 | the correct AUROC even if the predicted probabilities are all close to 0. 60 | Currently only supports binary classification. 61 | """ 62 | 63 | def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs): 64 | """ 65 | 66 | Args: 67 | label_threshold: Labels strictly above this threshold are considered positive labels, 68 | otherwise, they are considered negative. 69 | raise_missing_class: If True, a ValueError will be raised if negative or positive class is missing. 70 | Otherwise, we will simply log a warning. 71 | **kwargs: Additional parameters supported by all torchmetrics.Metric.
72 | """ 73 | super().__init__(**kwargs) 74 | self.add_state("predictions", default=[], dist_reduce_fx="cat") 75 | self.add_state("target", default=[], dist_reduce_fx="cat") 76 | self.add_state("weights", default=[], dist_reduce_fx="cat") 77 | 78 | self.label_threshold = label_threshold 79 | self.raise_missing_class = raise_missing_class 80 | 81 | def update( 82 | self, 83 | predictions: torch.Tensor, 84 | target: torch.Tensor, 85 | weight: Union[float, torch.Tensor] = 1.0, 86 | ) -> None: 87 | """ 88 | Accumulate a batch of predictions, targets and weights; the AUROC itself is computed in compute(). 89 | Args: 90 | predictions: Predicted values, 1D Tensor or 2D Tensor of shape batch_size x 1. 91 | target: Ground truth. Must have same shape as predictions. 92 | weight: The weight to use for the predicted values. Shape should be 93 | broadcastable to that of predictions. 94 | """ 95 | self.predictions.append(predictions) 96 | self.target.append(target) 97 | if not isinstance(weight, torch.Tensor): 98 | weight = torch.as_tensor(weight, dtype=predictions.dtype, device=target.device) 99 | self.weights.append(torch.broadcast_to(weight, predictions.size())) 100 | 101 | def compute(self) -> torch.Tensor: 102 | """ 103 | Compute and return the accumulated AUROC. 104 | """ 105 | weights = dim_zero_cat(self.weights) 106 | predictions = dim_zero_cat(self.predictions) 107 | target = dim_zero_cat(self.target).type_as(predictions) 108 | 109 | negative_mask = target <= self.label_threshold # labels strictly above the threshold are positive 110 | positive_mask = torch.logical_not(negative_mask) 111 | 112 | if not negative_mask.any(): 113 | msg = "Negative class missing. AUROC returned will be meaningless." 114 | if self.raise_missing_class: 115 | raise ValueError(msg) 116 | else: 117 | logging.warning(msg)
118 | if not positive_mask.any(): 119 | msg = "Positive class missing. AUROC returned will be meaningless." 120 | if self.raise_missing_class: 121 | raise ValueError(msg) 122 | else: 123 | logging.warning(msg) 124 | 125 | weighted_actual_negative_sum = torch.sum( 126 | torch.where(negative_mask, weights, torch.zeros_like(weights)) 127 | ) 128 | 129 | weighted_actual_positive_sum = torch.sum( 130 | torch.where(positive_mask, weights, torch.zeros_like(weights)) 131 | ) 132 | 133 | max_positive_negative_weighted_sum = torch.max( 134 | weighted_actual_negative_sum, weighted_actual_positive_sum 135 | ) 136 | 137 | min_positive_negative_weighted_sum = torch.min( 138 | weighted_actual_negative_sum, weighted_actual_positive_sum 139 | ) 140 | 141 | # Compute auroc with the weight set to 1 when positive & negative have identical scores. 142 | auroc_le = _compute_helper( 143 | target=target, 144 | weights=weights, 145 | predictions=predictions, 146 | min_positive_negative_weighted_sum=min_positive_negative_weighted_sum, 147 | max_positive_negative_weighted_sum=max_positive_negative_weighted_sum, 148 | equal_predictions_as_incorrect=False, 149 | ) 150 | 151 | # Compute auroc with the weight set to 0 when positive & negative have identical scores. 152 | auroc_lt = _compute_helper( 153 | target=target, 154 | weights=weights, 155 | predictions=predictions, 156 | min_positive_negative_weighted_sum=min_positive_negative_weighted_sum, 157 | max_positive_negative_weighted_sum=max_positive_negative_weighted_sum, 158 | equal_predictions_as_incorrect=True, 159 | ) 160 | 161 | # Compute auroc with the weight set to 1/2 when positive & negative have identical scores.
162 | return auroc_le - (auroc_le - auroc_lt) / 2.0 163 | -------------------------------------------------------------------------------- /projects/home/recap/optimizer/optimizer.py: -------------------------------------------------------------------------------- 1 | """Build optimizers and learning rate schedules.""" 2 | import bisect 3 | from collections import defaultdict 4 | import functools 5 | import math 6 | import typing 7 | from typing import Optional 8 | import warnings 9 | 10 | # from large_embeddings.config import EmbeddingOptimizerConfig 11 | from tml.projects.home.recap import model as model_mod 12 | from tml.optimizers import config 13 | from tml.optimizers import compute_lr 14 | from absl import logging # type: ignore[attr-defined] 15 | 16 | import torch 17 | from torchrec.optim import keyed 18 | 19 | 20 | _DEFAULT_LR = 24601.0 # NaN the model if we're not using the learning rate. 21 | _BACKBONE = "backbone" 22 | _DENSE_EMBEDDINGS = "dense_ebc" 23 | 24 | 25 | class RecapLRShim(torch.optim.lr_scheduler._LRScheduler): 26 | """Shim to get learning rates into a LRScheduler. 27 | 28 | This adheres to the torch.optim scheduler API and can be plugged anywhere that 29 | e.g. exponential decay can be used.
30 | 31 | """ 32 | 33 | def __init__( 34 | self, 35 | optimizer, 36 | lr_dict: typing.Dict[str, config.LearningRate], 37 | emb_learning_rate, 38 | last_epoch=-1, 39 | verbose=False, 40 | ): 41 | self.optimizer = optimizer 42 | self.lr_dict = lr_dict 43 | self.group_names = list(self.lr_dict.keys()) 44 | self.emb_learning_rate = emb_learning_rate 45 | 46 | # We handle sparse LR scheduling separately, so only validate LR groups against dense param groups 47 | num_dense_param_groups = sum( 48 | 1 49 | for _, _optim in optimizer._optims 50 | for _ in _optim.param_groups 51 | if isinstance(_optim, keyed.KeyedOptimizerWrapper) 52 | ) 53 | if num_dense_param_groups != len(lr_dict): 54 | raise ValueError( 55 | f"Optimizer had {len(optimizer.param_groups)}, but config had {len(lr_dict)}." 56 | ) 57 | super().__init__(optimizer, last_epoch, verbose) 58 | 59 | def get_lr(self): 60 | if not self._get_lr_called_within_step: 61 | warnings.warn( 62 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", 63 | UserWarning, 64 | ) 65 | return self._get_closed_form_lr() 66 | 67 | def _get_closed_form_lr(self): 68 | learning_rates = [] 69 | 70 | for lr_config in self.lr_dict.values(): 71 | learning_rates.append(compute_lr(lr_config, self.last_epoch)) 72 | # WARNING: The order of appending is important. 73 | if self.emb_learning_rate: 74 | learning_rates.append(compute_lr(self.emb_learning_rate, self.last_epoch)) 75 | return learning_rates 76 | 77 | 78 | def build_optimizer( 79 | model: torch.nn.Module, 80 | optimizer_config: config.OptimizerConfig, 81 | emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None, 82 | ): 83 | """Builds an optimizer and scheduler. 84 | 85 | Args: 86 | model: A torch model, probably with DDP/DMP. 87 | optimizer_config: An OptimizerConfig object that specifies learning rates per tower. 88 | 89 | Returns: 90 | A torch.optim instance, and a scheduler instance. 
91 | """ 92 | optimizer_fn = functools.partial( 93 | torch.optim.Adam, 94 | lr=_DEFAULT_LR, 95 | betas=(optimizer_config.adam.beta_1, optimizer_config.adam.beta_2), 96 | eps=optimizer_config.adam.epsilon, 97 | maximize=False, 98 | ) 99 | if optimizer_config.multi_task_learning_rates: 100 | logging.info("***** Parameter groups for optimization *****") 101 | # Importantly, we preserve insertion order in dictionaries here. 102 | parameter_groups: typing.Dict[str, typing.Dict] = defaultdict(dict) 103 | added_parameters: typing.Set[str] = set() 104 | for task in optimizer_config.multi_task_learning_rates.tower_learning_rates: 105 | for name, parameter in model.named_parameters(): 106 | if f".{model_mod.sanitize(task)}." in name: 107 | parameter_groups[task][name] = parameter 108 | logging.info(f"{task}: {name}") 109 | if name in added_parameters: 110 | raise ValueError(f"Parameter {name} matched multiple tasks.") 111 | added_parameters.add(name) 112 | 113 | for name, parameter in model.named_parameters(): 114 | if name not in added_parameters and "embedding_bags" not in name: 115 | parameter_groups[_BACKBONE][name] = parameter 116 | added_parameters.add(name) 117 | logging.info(f"{_BACKBONE}: {name}") 118 | 119 | for name, parameter in model.named_parameters(): 120 | if name not in added_parameters and "embedding_bags" in name: 121 | parameter_groups[_DENSE_EMBEDDINGS][name] = parameter 122 | logging.info(f"{_DENSE_EMBEDDINGS}: {name}") 123 | 124 | all_learning_rates = optimizer_config.multi_task_learning_rates.tower_learning_rates.copy() 125 | if optimizer_config.multi_task_learning_rates.backbone_learning_rate is not None: 126 | all_learning_rates[ 127 | _BACKBONE 128 | ] = optimizer_config.multi_task_learning_rates.backbone_learning_rate 129 | if _DENSE_EMBEDDINGS in parameter_groups and emb_optimizer_config: 130 | all_learning_rates[_DENSE_EMBEDDINGS] = emb_optimizer_config.learning_rate.copy() 131 | else: 132 | parameter_groups = dict(model.named_parameters()) 133 
| all_learning_rates = {"single_task": optimizer_config.single_task_learning_rate} 134 | 135 | optimizers = [ 136 | keyed.KeyedOptimizerWrapper(param_group, optimizer_fn) 137 | for param_name, param_group in parameter_groups.items() 138 | if param_name != _DENSE_EMBEDDINGS 139 | ] 140 | # Making EBC optimizer to be SGD to match fused optimiser 141 | if _DENSE_EMBEDDINGS in parameter_groups: 142 | optimizers.append( 143 | keyed.KeyedOptimizerWrapper( 144 | parameter_groups[_DENSE_EMBEDDINGS], 145 | functools.partial(torch.optim.SGD, lr=_DEFAULT_LR, maximize=False, momentum=False), 146 | ) 147 | ) 148 | 149 | if not parameter_groups.keys() == all_learning_rates.keys(): 150 | raise ValueError("Learning rates do not match optimizers") 151 | 152 | # If the optimiser is dense, model.fused_optimizer will be empty (but not None) 153 | emb_learning_rate = None 154 | if hasattr(model, "fused_optimizer") and model.fused_optimizer.optimizers: 155 | logging.info(f"Model fused optimiser: {model.fused_optimizer}") 156 | optimizers.append(model.fused_optimizer) 157 | if emb_optimizer_config: 158 | emb_learning_rate = emb_optimizer_config.learning_rate.copy() 159 | else: 160 | raise ValueError("Fused kernel exists, but LR is not set") 161 | logging.info(f"***** Combining optimizers: {optimizers} *****") 162 | optimizer = keyed.CombinedOptimizer(optimizers) 163 | scheduler = RecapLRShim(optimizer, all_learning_rates, emb_learning_rate) 164 | logging.info(f"***** Combined optimizer after init: {optimizer} *****") 165 | 166 | return optimizer, scheduler 167 | -------------------------------------------------------------------------------- /projects/home/recap/data/preprocessors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocessors applied on DDS workers in order to modify the dataset on the fly. 3 | Some of these preprocessors are also applied to the model at serving time. 
4 | """ 5 | from tml.projects.home.recap import config as config_mod 6 | from absl import logging 7 | import tensorflow as tf 8 | import numpy as np 9 | 10 | 11 | class TruncateAndSlice(tf.keras.Model): 12 | """Class for truncating and slicing.""" 13 | 14 | def __init__(self, truncate_and_slice_config): 15 | super().__init__() 16 | self._truncate_and_slice_config = truncate_and_slice_config 17 | 18 | if self._truncate_and_slice_config.continuous_feature_mask_path: 19 | with tf.io.gfile.GFile( 20 | self._truncate_and_slice_config.continuous_feature_mask_path, "rb" 21 | ) as f: 22 | self._continuous_mask = np.load(f).nonzero()[0] 23 | logging.info(f"Slicing {np.sum(self._continuous_mask)} continuous features.") 24 | else: 25 | self._continuous_mask = None 26 | 27 | if self._truncate_and_slice_config.binary_feature_mask_path: 28 | with tf.io.gfile.GFile(self._truncate_and_slice_config.binary_feature_mask_path, "rb") as f: 29 | self._binary_mask = np.load(f).nonzero()[0] 30 | logging.info(f"Slicing {np.sum(self._binary_mask)} binary features.") 31 | else: 32 | self._binary_mask = None 33 | 34 | def call(self, inputs, training=None, mask=None): 35 | outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs)) 36 | if self._truncate_and_slice_config.continuous_feature_truncation: 37 | logging.info("Truncating continuous") 38 | outputs["continuous"] = outputs["continuous"][ 39 | :, : self._truncate_and_slice_config.continuous_feature_truncation 40 | ] 41 | if self._truncate_and_slice_config.binary_feature_truncation: 42 | logging.info("Truncating binary") 43 | outputs["binary"] = outputs["binary"][ 44 | :, : self._truncate_and_slice_config.binary_feature_truncation 45 | ] 46 | if self._continuous_mask is not None: 47 | outputs["continuous"] = tf.gather(outputs["continuous"], self._continuous_mask, axis=1) 48 | if self._binary_mask is not None: 49 | outputs["binary"] = tf.gather(outputs["binary"], self._binary_mask, axis=1) 50 | return outputs 51 | 52 | 53 | class 
DownCast(tf.keras.Model): 54 | """Class for Down casting dataset before serialization and transferring to training host. 55 | Depends on the data type and the actual data range, the down casting can be lossless or not. 56 | It is strongly recommended to compare the metrics before and after down casting. 57 | """ 58 | 59 | def __init__(self, downcast_config): 60 | super().__init__() 61 | self.config = downcast_config 62 | self._type_map = { 63 | "bfloat16": tf.bfloat16, 64 | "bool": tf.bool, 65 | } 66 | 67 | def call(self, inputs, training=None, mask=None): 68 | outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs)) 69 | for feature, type_str in self.config.features.items(): 70 | assert type_str in self._type_map 71 | if type_str == "bfloat16": 72 | logging.warning( 73 | "Although bfloat16 and float32 have the same number of exponent bits, this down casting is not 100% lossless. Please double check metrics." 74 | ) 75 | down_cast_data_type = self._type_map[type_str] 76 | outputs[feature] = tf.cast(outputs[feature], dtype=down_cast_data_type) 77 | return outputs 78 | 79 | 80 | class RectifyLabels(tf.keras.Model): 81 | """Class for rectifying labels""" 82 | 83 | def __init__(self, rectify_label_config): 84 | super().__init__() 85 | self._config = rectify_label_config 86 | self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000) 87 | 88 | def call(self, inputs, training=None, mask=None): 89 | served_ts_field = self._config.served_timestamp_field 90 | impressed_ts_field = self._config.impressed_timestamp_field 91 | 92 | for label, engaged_ts_field in self._config.label_to_engaged_timestamp_field.items(): 93 | impressed = inputs[impressed_ts_field] 94 | served = inputs[served_ts_field] 95 | engaged = inputs[engaged_ts_field] 96 | 97 | keep = tf.math.logical_and(inputs[label] > 0, impressed - served < self._window) 98 | keep = tf.math.logical_and(keep, engaged - served < self._window) 99 | inputs[label] = tf.where(keep, 
inputs[label], tf.zeros_like(inputs[label])) 100 | 101 | return inputs 102 | 103 | 104 | class ExtractFeatures(tf.keras.Model): 105 | """Class for extracting individual features from dense tensors by their index.""" 106 | 107 | def __init__(self, extract_features_config): 108 | super().__init__() 109 | self._config = extract_features_config 110 | 111 | def call(self, inputs, training=None, mask=None): 112 | 113 | for row in self._config.extract_feature_table: 114 | inputs[row.name] = inputs[row.source_tensor][:, row.index] 115 | 116 | return inputs 117 | 118 | 119 | class DownsampleNegatives(tf.keras.Model): 120 | """Class for down-sampling/dropping negatives and updating the weights. 121 | 122 | If inputs['fav'] = [1, 0, 0, 0] and inputs['weights'] = [1.0, 1.0, 1.0, 1.0] 123 | inputs are transformed to inputs['fav'] = [1, 0] and inputs['weights'] = [1.0, 3.0] 124 | when batch_multiplier=2 and engagements_list=['fav'] 125 | 126 | It supports multiple engagements (union/logical_or is used to aggregate engagements), so we don't 127 | drop positives for any engagement. 
128 | """ 129 | 130 | def __init__(self, downsample_negatives_config): 131 | super().__init__() 132 | self.config = downsample_negatives_config 133 | 134 | def call(self, inputs, training=None, mask=None): 135 | labels = self.config.engagements_list 136 | # union of engagements 137 | mask = tf.squeeze(tf.reduce_any(tf.stack([inputs[label] == 1 for label in labels], 1), 1)) 138 | n_positives = tf.reduce_sum(tf.cast(mask, tf.int32)) 139 | batch_size = tf.cast(tf.shape(inputs[labels[0]])[0] / self.config.batch_multiplier, tf.int32) 140 | negative_weights = tf.math.divide_no_nan( 141 | tf.cast(self.config.batch_multiplier * batch_size - n_positives, tf.float32), 142 | tf.cast(batch_size - n_positives, tf.float32), 143 | ) 144 | new_weights = tf.cast(mask, tf.float32) + (1 - tf.cast(mask, tf.float32)) * negative_weights 145 | 146 | def _split_by_label_concatenate_and_truncate(input_tensor): 147 | # takes positive examples and concatenate with negative examples and truncate 148 | # DANGER: if n_positives > batch_size down-sampling is incorrect (do not use pb_50) 149 | return tf.concat( 150 | [ 151 | input_tensor[mask], 152 | input_tensor[tf.math.logical_not(mask)], 153 | ], 154 | 0, 155 | )[:batch_size] 156 | 157 | if "weights" not in inputs: 158 | # add placeholder so logic below applies even if weights aren't present in inputs 159 | inputs["weights"] = tf.ones([tf.shape(inputs[labels[0]])[0], self.config.num_engagements]) 160 | 161 | for tensor in inputs: 162 | if tensor == "weights": 163 | inputs[tensor] = inputs[tensor] * tf.reshape(new_weights, [-1, 1]) 164 | 165 | inputs[tensor] = _split_by_label_concatenate_and_truncate(inputs[tensor]) 166 | 167 | return inputs 168 | 169 | 170 | def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN): 171 | """Builds a preprocess model to apply all preprocessing stages.""" 172 | if mode == config_mod.JobMode.INFERENCE: 173 | logging.info("Not building preprocessors for dataloading since we are in Inference mode.") 
174 | return None 175 | 176 | preprocess_models = [] 177 | if preprocess_config.downsample_negatives: 178 | preprocess_models.append(DownsampleNegatives(preprocess_config.downsample_negatives)) 179 | if preprocess_config.truncate_and_slice: 180 | preprocess_models.append(TruncateAndSlice(preprocess_config.truncate_and_slice)) 181 | if preprocess_config.downcast: 182 | preprocess_models.append(DownCast(preprocess_config.downcast)) 183 | if preprocess_config.rectify_labels: 184 | preprocess_models.append(RectifyLabels(preprocess_config.rectify_labels)) 185 | if preprocess_config.extract_features: 186 | preprocess_models.append(ExtractFeatures(preprocess_config.extract_features)) 187 | 188 | if len(preprocess_models) == 0: 189 | raise ValueError("No known preprocessor.") 190 | 191 | class PreprocessModel(tf.keras.Model): 192 | def __init__(self, preprocess_models): 193 | super().__init__() 194 | self.preprocess_models = preprocess_models 195 | 196 | def call(self, inputs, training=None, mask=None): 197 | outputs = inputs 198 | for model in self.preprocess_models: 199 | outputs = model(outputs, training, mask) 200 | return outputs 201 | 202 | if len(preprocess_models) > 1: 203 | logging.warning( 204 | "With multiple preprocessing models, we apply these models in a predefined order. Future works may introduce customized models and orders." 
205 | ) 206 | return PreprocessModel(preprocess_models) 207 | -------------------------------------------------------------------------------- /projects/home/recap/data/config.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from enum import Enum 3 | 4 | 5 | from tml.core import config as base_config 6 | 7 | import pydantic 8 | 9 | 10 | class ExplicitDateInputs(base_config.BaseConfig): 11 | """Arguments to select train/validation data using end_date and days of data.""" 12 | 13 | data_root: str = pydantic.Field(..., description="Data path prefix.") 14 | end_date: str = pydantic.Field(..., description="Data end date, inclusive.") 15 | days: int = pydantic.Field(..., description="Number of days of data for dataset.") 16 | num_missing_days_tol: int = pydantic.Field( 17 | 0, description="We tolerate <= num_missing_days_tol days of missing data." 18 | ) 19 | 20 | 21 | class ExplicitDatetimeInputs(base_config.BaseConfig): 22 | """Arguments to select train/validation data using end_datetime and hours of data.""" 23 | 24 | data_root: str = pydantic.Field(..., description="Data path prefix.") 25 | end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.") 26 | hours: int = pydantic.Field(..., description="Number of hours of data for dataset.") 27 | num_missing_hours_tol: int = pydantic.Field( 28 | 0, description="We tolerate <= num_missing_hours_tol hours of missing data." 
29 | ) 30 | 31 | 32 | class DdsCompressionOption(str, Enum): 33 | """The only valid compression option is 'AUTO'""" 34 | 35 | AUTO = "AUTO" 36 | 37 | 38 | class DatasetConfig(base_config.BaseConfig): 39 | inputs: str = pydantic.Field( 40 | None, description="A glob for selecting data.", one_of="date_inputs_format" 41 | ) 42 | explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field( 43 | None, one_of="date_inputs_format" 44 | ) 45 | explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format") 46 | 47 | global_batch_size: pydantic.PositiveInt 48 | 49 | num_files_to_keep: pydantic.PositiveInt = pydantic.Field( 50 | None, description="Number of shards to keep." 51 | ) 52 | repeat_files: bool = pydantic.Field( 53 | True, description="DEPRICATED. Files are repeated no matter what this is set to." 54 | ) 55 | file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size") 56 | 57 | cache: bool = pydantic.Field( 58 | False, 59 | description="Cache dataset in memory. Careful to only use this when you" 60 | " have enough memory to fit entire dataset.", 61 | ) 62 | 63 | data_service_dispatcher: str = pydantic.Field(None) 64 | ignore_data_errors: bool = pydantic.Field( 65 | False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs." 66 | ) 67 | dataset_service_compression: DdsCompressionOption = pydantic.Field( 68 | None, 69 | description="Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'", 70 | ) 71 | 72 | # tf.data.Dataset options 73 | examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.") 74 | map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field( 75 | None, description="Number of parallel calls." 76 | ) 77 | interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field( 78 | None, description="Number of shards to interleave." 
79 | ) 80 | 81 | 82 | class TruncateAndSlice(base_config.BaseConfig): 83 | # Apply truncation and then slice. 84 | continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field( 85 | None, description="Experimental. Truncates continuous features to this amount for efficiency." 86 | ) 87 | binary_feature_truncation: pydantic.PositiveInt = pydantic.Field( 88 | None, description="Experimental. Truncates binary features to this amount for efficiency." 89 | ) 90 | 91 | continuous_feature_mask_path: str = pydantic.Field( 92 | None, description="Path of mask used to slice input continuous features." 93 | ) 94 | binary_feature_mask_path: str = pydantic.Field( 95 | None, description="Path of mask used to slice input binary features." 96 | ) 97 | 98 | 99 | class DataType(str, Enum): 100 | BFLOAT16 = "bfloat16" 101 | BOOL = "bool" 102 | 103 | FLOAT32 = "float32" 104 | FLOAT16 = "float16" 105 | 106 | UINT8 = "uint8" 107 | 108 | 109 | class DownCast(base_config.BaseConfig): 110 | # Apply down casting to selected features. 111 | features: typing.Dict[str, DataType] = pydantic.Field( 112 | None, description="Map features to down cast data types." 113 | ) 114 | 115 | 116 | class TaskData(base_config.BaseConfig): 117 | pos_downsampling_rate: float = pydantic.Field( 118 | 1.0, 119 | description="Downsampling rate of positives used to generate dataset.", 120 | ) 121 | neg_downsampling_rate: float = pydantic.Field( 122 | 1.0, 123 | description="Downsampling rate of negatives used to generate dataset.", 124 | ) 125 | 126 | 127 | class SegDenseSchema(base_config.BaseConfig): 128 | schema_path: str = pydantic.Field(..., description="Path to feature config json.") 129 | features: typing.List[str] = pydantic.Field( 130 | [], 131 | description="List of features (in addition to the renamed features) to read from schema path above.", 132 | ) 133 | renamed_features: typing.Dict[str, str] = pydantic.Field( 134 | {}, description="Dictionary of renamed features." 
135 | ) 136 | mask_mantissa_features: typing.Dict[str, int] = pydantic.Field( 137 | {}, 138 | description="(experimental) Number of mantissa bits to mask to simulate lower precision data.", 139 | ) 140 | 141 | 142 | class RectifyLabels(base_config.BaseConfig): 143 | label_rectification_window_in_hours: float = pydantic.Field( 144 | 3.0, description="overlap time in hours for which to flip labels" 145 | ) 146 | served_timestamp_field: str = pydantic.Field( 147 | ..., description="input field corresponding to served time" 148 | ) 149 | impressed_timestamp_field: str = pydantic.Field( 150 | ..., description="input field corresponding to impressed time" 151 | ) 152 | label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field( 153 | ..., description="label to the input field corresponding to engagement time" 154 | ) 155 | 156 | 157 | class ExtractFeaturesRow(base_config.BaseConfig): 158 | name: str = pydantic.Field( 159 | ..., 160 | description="name of the new field name to be created", 161 | ) 162 | source_tensor: str = pydantic.Field( 163 | ..., 164 | description="name of the dense tensor to look for the feature", 165 | ) 166 | index: int = pydantic.Field( 167 | ..., 168 | description="index of the feature in the dense tensor", 169 | ) 170 | 171 | 172 | class ExtractFeatures(base_config.BaseConfig): 173 | extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field( 174 | [], 175 | description="list of features to be extracted with their name, source tensor and index", 176 | ) 177 | 178 | 179 | class DownsampleNegatives(base_config.BaseConfig): 180 | batch_multiplier: int = pydantic.Field( 181 | None, 182 | description="batch multiplier", 183 | ) 184 | engagements_list: typing.List[str] = pydantic.Field( 185 | [], 186 | description="engagements with kept positives", 187 | ) 188 | num_engagements: int = pydantic.Field( 189 | ..., 190 | description="number engagements used in the model, including ones excluded in engagements_list", 191 | ) 192 


class Preprocess(base_config.BaseConfig):
    # Optional preprocessing stages; each is built only when its sub-config is set.
    truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
    downcast: DownCast = pydantic.Field(None, description="Down cast to features.")
    rectify_labels: RectifyLabels = pydantic.Field(
        None, description="Rectify labels for a given overlap window"
    )
    extract_features: ExtractFeatures = pydantic.Field(
        None, description="Extract features from dense tensors."
    )
    downsample_negatives: DownsampleNegatives = pydantic.Field(
        None, description="Downsample negatives."
    )


class Sampler(base_config.BaseConfig):
    """Assumes function is defined in data/samplers.py.

    Only use this for quick experimentation.
    If samplers are useful, we should sample from upstream data generation.

    DEPRICATED, DO NOT USE.
    """

    name: str
    kwargs: typing.Dict


class RecapDataConfig(DatasetConfig):
    seg_dense_schema: SegDenseSchema

    tasks: typing.Dict[str, TaskData] = pydantic.Field(
        description="Description of individual tasks in this dataset."
    )
    evaluation_tasks: typing.List[str] = pydantic.Field(
        [], description="If specified, lists the tasks we're generating metrics for."
    )

    preprocess: Preprocess = pydantic.Field(
        None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference."
    )

    sampler: Sampler = pydantic.Field(
        None,
        description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""",
    )

    @pydantic.root_validator()
    def _validate_evaluation_tasks(cls, values):
        """Check that every evaluation task is one of the configured ``tasks``.

        If ``tasks`` failed its own validation, it is absent from ``values``. In that case the check
        is skipped, so pydantic reports the field error instead of a bare ``KeyError('tasks')``.

        Raises:
            KeyError: If an evaluation task is not a key of ``tasks``. Pydantic does not wrap
                KeyError, so it propagates as-is (kept for compatibility).
        """
        evaluation_tasks = values.get("evaluation_tasks")
        tasks = values.get("tasks")
        if evaluation_tasks is None or tasks is None:
            return values
        for task in evaluation_tasks:
            if task not in tasks:
                raise KeyError(f"Evaluation task {task} must be in tasks. Received {tasks}")
        return values


# --------------------------------------------------------------------------------