├── examples ├── __init__.py ├── ray │ ├── __init__.py │ ├── requirements.txt │ ├── compute_world_size.py │ └── README.md ├── zch │ ├── __init__.py │ └── docs │ │ └── mpzch_module_dataflow.png ├── golden_training │ ├── __init__.py │ ├── tests │ │ ├── __init__.py │ │ └── test_train_dlrm.py │ └── README.md ├── transfer_learning │ ├── __init__.py │ └── README.md ├── dlrm │ └── README.MD ├── retrieval │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ └── dataloader.py │ ├── modules │ │ ├── __init__.py │ │ └── tests │ │ │ └── __init__.py │ └── tests │ │ ├── __init__.py │ │ ├── test_two_tower_retrieval.py │ │ └── test_two_tower_train.py ├── prediction │ └── __init__.py ├── bert4rec │ └── models │ │ └── tests │ │ └── test_bert4rec.py └── nvt_dataloader │ ├── aws_component.py │ └── README.md ├── version.txt ├── torchrec ├── metrics │ └── __init__.py ├── schema │ └── __init__.py ├── models │ ├── tests │ │ └── __init__.py │ ├── experimental │ │ └── __init__.py │ └── __init__.py ├── optim │ ├── tests │ │ ├── __init__.py │ │ └── test_optim.py │ ├── test_utils │ │ └── __init__.py │ ├── fused.py │ └── __init__.py ├── sparse │ ├── tests │ │ └── __init__.py │ ├── __init__.py │ └── tensor_dict.py ├── datasets │ ├── scripts │ │ ├── __init__.py │ │ ├── nvt │ │ │ ├── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── dask.py │ │ │ │ └── criteo_constant.py │ │ │ ├── Dockerfile │ │ │ └── nvt_preproc.sh │ │ └── tests │ │ │ └── test_npy_preproc_criteo.py │ ├── tests │ │ └── __init__.py │ ├── test_utils │ │ └── __init__.py │ └── __init__.py ├── distributed │ ├── tests │ │ ├── __init__.py │ │ ├── test_fused_optimizer.py │ │ ├── test_model_parallel_nccl.py │ │ ├── test_model_parallel_gloo_gpu.py │ │ ├── test_model_parallel_gloo_gpu_single_rank.py │ │ ├── test_model_parallel_gloo.py │ │ └── test_awaitable.py │ ├── benchmark │ │ ├── __init__.py │ │ ├── benchmark_zch │ │ │ ├── data │ │ │ │ ├── __init__.py │ │ │ │ ├── configs │ │ │ │ │ └── criteo_kaggle.yaml │ │ │ │ ├── Readme.md │ │ │ │ └── 
preprocess │ │ │ │ │ └── __init__.py │ │ │ ├── figures │ │ │ │ ├── eval_metrics_auc.png │ │ │ │ ├── eval_metrics_mae.png │ │ │ │ ├── eval_metrics_mse.png │ │ │ │ └── eval_metrics_ne.png │ │ │ └── models │ │ │ │ ├── models │ │ │ │ └── __init__.py │ │ │ │ ├── __init__.py │ │ │ │ └── configs │ │ │ │ ├── dlrmv3.yaml │ │ │ │ └── dlrmv2.yaml │ │ ├── README.md │ │ └── yaml │ │ │ ├── sparse_data_dist_ssd.yml │ │ │ ├── sparse_data_dist_emo.yml │ │ │ ├── sparse_data_dist_base.yml │ │ │ ├── base_pipeline_light.yml │ │ │ └── sparse_data_dist_base_vbe.yml │ ├── composable │ │ ├── __init__.py │ │ └── tests │ │ │ ├── __init__.py │ │ │ └── test_table_batched_embedding_slice.py │ ├── sharding │ │ ├── __init__.py │ │ └── twcw_sharding.py │ ├── test_utils │ │ └── __init__.py │ ├── logging_handlers.py │ ├── model_tracker │ │ ├── trackers │ │ │ └── __init__.py │ │ └── __init__.py │ ├── torchrec_logging_handlers.py │ ├── global_settings.py │ ├── planner │ │ ├── __init__.py │ │ └── perf_models.py │ └── train_pipeline │ │ └── __init__.py ├── modules │ ├── tests │ │ ├── __init__.py │ │ ├── test_activation.py │ │ └── test_code_quality.py │ ├── README.rst │ ├── object_pool.py │ ├── __init__.py │ ├── pruning_logger.py │ └── activation.py ├── inference │ ├── inference_legacy │ │ ├── src │ │ │ ├── Executer2.cpp │ │ │ └── Validation.cpp │ │ ├── include │ │ │ └── torchrec │ │ │ │ └── inference │ │ │ │ ├── JaggedTensor.h │ │ │ │ ├── Validation.h │ │ │ │ ├── TestUtils.h │ │ │ │ ├── ShardedTensor.h │ │ │ │ ├── ExceptionHandler.h │ │ │ │ └── Exception.h │ │ ├── __init__.py │ │ ├── protos │ │ │ └── predictor.proto │ │ └── tests │ │ │ ├── ValidationTest.cpp │ │ │ ├── test_modules.py │ │ │ └── predict_module_tests.py │ ├── __init__.py │ ├── include │ │ └── torchrec │ │ │ └── inference │ │ │ ├── JaggedTensor.h │ │ │ ├── Validation.h │ │ │ ├── TestUtils.h │ │ │ ├── ShardedTensor.h │ │ │ ├── ExceptionHandler.h │ │ │ ├── Exception.h │ │ │ └── ResultSplit.h │ └── protos │ │ └── predictor.proto ├── csrc │ ├── 
CMakeLists.txt │ └── dynamic_embedding │ │ ├── details │ │ ├── notification.cpp │ │ ├── types.h │ │ ├── notification.h │ │ ├── redis │ │ │ └── CMakeLists.txt │ │ ├── lxu_strategy.h │ │ ├── bits_op.h │ │ ├── io_parameter.h │ │ ├── io_registry.h │ │ ├── bitmap.h │ │ ├── bitmap_impl.h │ │ └── naive_id_transformer.h │ │ ├── CMakeLists.txt │ │ ├── bind.cpp │ │ └── id_transformer_wrapper.h ├── pt2 │ └── __init__.py ├── utils │ └── __init__.py ├── fx │ ├── __init__.py │ └── tests │ │ └── test_tracer.py ├── __init__.py ├── quant │ ├── __init__.py │ └── tests │ │ └── test_tensor_types.py ├── ir │ ├── types.py │ └── schema.py ├── streamable.py └── types.py ├── contrib └── dynamic_embedding │ ├── src │ ├── __init__.py │ ├── CMakeLists.txt │ ├── tde │ │ ├── details │ │ │ ├── move_only_function.cpp │ │ │ ├── naive_id_transformer.cpp │ │ │ ├── cacheline_id_transformer.cpp │ │ │ ├── redis_io.h │ │ │ ├── notification_test.cpp │ │ │ ├── move_only_function_test.cpp │ │ │ ├── notification.cpp │ │ │ ├── bits_op_test.cpp │ │ │ ├── notification.h │ │ │ ├── id_transformer_variant_test.cpp │ │ │ ├── bits_op.h │ │ │ ├── redis_io.cpp │ │ │ ├── random_bits_generator_benchmark.cpp │ │ │ ├── mixed_lfu_lru_strategy.cpp │ │ │ ├── cacheline_id_transformer_benchmark.cpp │ │ │ ├── naive_id_transformer_benchmark.cpp │ │ │ ├── mixed_lfu_lru_strategy_benchmark.cpp │ │ │ ├── url_test.cpp │ │ │ ├── move_only_function.h │ │ │ ├── id_transformer_variant_impl.h │ │ │ └── random_bits_generator.h │ │ ├── notification.h │ │ ├── tensor_list.h │ │ ├── id_transformer.h │ │ └── bind.cpp │ └── torchrec_dynamic_embedding │ │ ├── __init__.py │ │ ├── distributed │ │ └── __init__.py │ │ ├── utils.py │ │ ├── tensor_list.py │ │ └── id_transformer.py │ ├── tests │ ├── __init__.py │ ├── memory_io │ │ └── CMakeLists.txt │ └── utils.py │ ├── .gitignore │ ├── tools │ ├── repair_wheel.sh │ ├── build_wheels.sh │ └── before_linux_build.sh │ ├── CMakeLists.txt │ └── setup.py ├── docs ├── .gitignore ├── source │ ├── _static │ │ 
└── img │ │ │ ├── sharding.png │ │ │ ├── model_parallel.png │ │ │ ├── torchrec_forward.png │ │ │ ├── full_training_loop.png │ │ │ ├── fused_embedding_tables.png │ │ │ ├── fused_backward_optimizer.png │ │ │ └── card-background.svg │ ├── _templates │ │ └── layout.html │ ├── model-parallel-api-reference.rst │ ├── inference-api-reference.rst │ ├── datatypes-api-reference.rst │ ├── modules-api-reference.rst │ └── planner-api-reference.rst ├── requirements.txt ├── README.md ├── Makefile └── make.bat ├── .github ├── scripts │ ├── tests_to_skip.txt │ └── install_libs.sh └── workflows │ ├── pre-commit.yaml │ ├── pyre.yml │ ├── validate-nightly-binaries.yml │ └── validate-binaries.yml ├── pyproject.toml ├── install-requirements.txt ├── benchmarks ├── EBC_benchmarks_dlrm_emb.png └── cpp │ ├── CMakeLists.txt │ └── dynamic_embedding │ ├── CMakeLists.txt │ ├── random_bits_generator_benchmark.cpp │ ├── naive_id_transformer_benchmark.cpp │ └── mixed_lfu_lru_strategy_benchmark.cpp ├── rfc ├── RFC-0001-assets │ ├── cache-consistency.png │ └── logical-cache-consistency.png └── RFC-0002-assets │ ├── kv_tbe_prefetch_workflow.png │ ├── kv_tbe_training_high_level.png │ └── kv_tbe_pipeline_prefetching.png ├── test └── cpp │ ├── CMakeLists.txt │ └── dynamic_embedding │ ├── notification_test.cpp │ ├── redis │ ├── CMakeLists.txt │ └── url_test.cpp │ ├── bits_op_test.cpp │ └── CMakeLists.txt ├── requirements.txt ├── .pyre_configuration ├── .pre-commit-config.yaml ├── .lintrunner.toml ├── tools └── lint │ └── utils.py ├── CMakeLists.txt ├── LICENSE └── CONTRIBUTING.md /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 1.5.0a0 2 | -------------------------------------------------------------------------------- /examples/ray/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/zch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/schema/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/models/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/optim/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/sparse/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/golden_training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/transfer_learning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/datasets/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/torchrec/datasets/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/distributed/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/modules/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/golden_training/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/datasets/test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/distributed/composable/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/distributed/sharding/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/distributed/test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/models/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | src/pytorch-sphinx-theme/ 2 | -------------------------------------------------------------------------------- /torchrec/datasets/scripts/nvt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/distributed/composable/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/distributed/tests/test_fused_optimizer.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/inference/inference_legacy/src/Executer2.cpp: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/benchmark_zch/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/scripts/tests_to_skip.txt: 
-------------------------------------------------------------------------------- 1 | _disabled_in_oss_compatibility 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.usort] 2 | 3 | first_party_detection = false 4 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(3rd) 2 | add_subdirectory(tde) 3 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/move_only_function.cpp: -------------------------------------------------------------------------------- 1 | #include "move_only_function.h" 2 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/naive_id_transformer.cpp: -------------------------------------------------------------------------------- 1 | #include "naive_id_transformer.h" 2 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/torchrec_dynamic_embedding/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataloader import save, wrap 2 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | dist 3 | *egg-info 4 | cmake-build* 5 | _skbuild 6 | wheelhouse 7 | -------------------------------------------------------------------------------- /install-requirements.txt: -------------------------------------------------------------------------------- 1 | fbgemm-gpu 2 | tensordict 3 | torchmetrics==1.0.3 4 | tqdm 5 | pyre-extensions 6 | 
iopath 7 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.cpp: -------------------------------------------------------------------------------- 1 | #include "tde/details/cacheline_id_transformer.h" 2 | -------------------------------------------------------------------------------- /benchmarks/EBC_benchmarks_dlrm_emb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/benchmarks/EBC_benchmarks_dlrm_emb.png -------------------------------------------------------------------------------- /docs/source/_static/img/sharding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/docs/source/_static/img/sharding.png -------------------------------------------------------------------------------- /docs/source/_static/img/model_parallel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/docs/source/_static/img/model_parallel.png -------------------------------------------------------------------------------- /examples/zch/docs/mpzch_module_dataflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/examples/zch/docs/mpzch_module_dataflow.png -------------------------------------------------------------------------------- /rfc/RFC-0001-assets/cache-consistency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/rfc/RFC-0001-assets/cache-consistency.png -------------------------------------------------------------------------------- /docs/source/_static/img/torchrec_forward.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/docs/source/_static/img/torchrec_forward.png -------------------------------------------------------------------------------- /docs/source/_static/img/full_training_loop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/docs/source/_static/img/full_training_loop.png -------------------------------------------------------------------------------- /examples/golden_training/README.md: -------------------------------------------------------------------------------- 1 | # Golden Training Example 2 | 3 | ## Running 4 | `torchx run -s local_cwd dist.ddp -j 1x2 --script train_dlrm.py` 5 | -------------------------------------------------------------------------------- /rfc/RFC-0002-assets/kv_tbe_prefetch_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/rfc/RFC-0002-assets/kv_tbe_prefetch_workflow.png -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/redis_io.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace tde::details { 4 | 5 | extern void RegisterRedisIO(); 6 | 7 | } 8 | -------------------------------------------------------------------------------- /docs/source/_static/img/fused_embedding_tables.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/docs/source/_static/img/fused_embedding_tables.png -------------------------------------------------------------------------------- /rfc/RFC-0001-assets/logical-cache-consistency.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/rfc/RFC-0001-assets/logical-cache-consistency.png -------------------------------------------------------------------------------- /rfc/RFC-0002-assets/kv_tbe_training_high_level.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/rfc/RFC-0002-assets/kv_tbe_training_high_level.png -------------------------------------------------------------------------------- /docs/source/_static/img/fused_backward_optimizer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/docs/source/_static/img/fused_backward_optimizer.png -------------------------------------------------------------------------------- /rfc/RFC-0002-assets/kv_tbe_pipeline_prefetching.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/rfc/RFC-0002-assets/kv_tbe_pipeline_prefetching.png -------------------------------------------------------------------------------- /examples/dlrm/README.MD: -------------------------------------------------------------------------------- 1 | # TorchRec DLRM Example 2 | 3 | See [Facebookresearch/dlrm torchrec implementation](https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm/) 4 | -------------------------------------------------------------------------------- /examples/ray/requirements.txt: -------------------------------------------------------------------------------- 1 | --pre torchrec -f https://download.pytorch.org/whl/torchrec/index.html 2 | torchx-nightly 3 | torch --extra-index-url https://download.pytorch.org/whl/cu113 4 | -------------------------------------------------------------------------------- 
/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_auc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_auc.png -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_mae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_mae.png -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_mse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_mse.png -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_ne.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_ne.png -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==5.0.0 2 | pyre-extensions 3 | sphinx-design 4 | sphinx_copybutton 5 | # torch 6 | # PyTorch Theme 7 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 8 | -------------------------------------------------------------------------------- /docs/source/_templates/layout.html: -------------------------------------------------------------------------------- 1 | 
{% extends "!layout.html" %} 2 | 3 | {% block footer %} 4 | {{ super() }} 5 | 6 | 9 | 10 | {% endblock %} 11 | -------------------------------------------------------------------------------- /examples/retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/tests/memory_io/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # A dummy ParamServer IO plugin. It just store params into memory 2 | add_library(memory_io SHARED memory_io.cpp) 3 | set_target_properties(memory_io PROPERTIES PREFIX "") 4 | target_compile_features(memory_io PUBLIC cxx_std_17) 5 | -------------------------------------------------------------------------------- /examples/retrieval/data/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /examples/retrieval/modules/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | -------------------------------------------------------------------------------- /examples/retrieval/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /test/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | add_subdirectory(dynamic_embedding) 8 | -------------------------------------------------------------------------------- /benchmarks/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | add_subdirectory(dynamic_embedding) 8 | -------------------------------------------------------------------------------- /examples/retrieval/modules/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | -------------------------------------------------------------------------------- /torchrec/csrc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | add_subdirectory(dynamic_embedding) 8 | -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/benchmark_zch/data/configs/criteo_kaggle.yaml: -------------------------------------------------------------------------------- 1 | dataset_path: "/home/lizhouyu/oss_github/dlrm/torchrec_dlrm/criteo_1tb/criteo_kaggle_processed" 2 | batch_size: 4096 3 | seed: 0 4 | multitask_configs: 5 | - task_name: is_click 6 | task_weight: 1 7 | task_type: classification 8 | -------------------------------------------------------------------------------- /torchrec/datasets/scripts/nvt/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/merlin/merlin-pytorch-training:nightly 2 | 3 | RUN conda install -y pytorch cudatoolkit=11.3 -c pytorch-nightly \ 4 | && pip install --pre torchrec_nightly -f https://download.pytorch.org/whl/nightly/torchrec_nightly/index.html 5 | 6 | WORKDIR /app 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black 2 | click 3 | cmake 4 | fbgemm-gpu>=1.4.0 5 | hypothesis==6.70.1 6 | importlib-metadata 7 | iopath 8 | numpy 9 | pandas 10 | pyre-extensions 11 | scikit-build 12 | tensordict 13 | torchmetrics==1.0.3 14 | torchx 15 | tqdm 16 | usort 17 | parameterized 18 | PyYAML 19 | psutil 20 | expecttest>=0.3.0 21 | 
-------------------------------------------------------------------------------- /docs/source/model-parallel-api-reference.rst: -------------------------------------------------------------------------------- 1 | Model Parallel 2 | ---------------------------------- 3 | 4 | ``DistributedModelParallel`` is the main API for distributed training with TorchRec optimizations. 5 | 6 | 7 | .. automodule:: torchrec.distributed.model_parallel 8 | 9 | .. autoclass:: DistributedModelParallel 10 | :members: 11 | -------------------------------------------------------------------------------- /torchrec/pt2/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | # __init__ for python module packaging 11 | -------------------------------------------------------------------------------- /torchrec/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from . import experimental # noqa 11 | -------------------------------------------------------------------------------- /torchrec/inference/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | 10 | from . import model_packager, modules # noqa # noqa 11 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/notification_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "gtest/gtest.h" 3 | #include "tde/details/notification.h" 4 | namespace tde::details { 5 | TEST(TDE, notification) { 6 | Notification notification; 7 | std::thread th([&] { notification.Done(); }); 8 | notification.Wait(); 9 | th.join(); 10 | } 11 | } // namespace tde::details 12 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/move_only_function_test.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "tde/details/move_only_function.h" 3 | namespace tde::details { 4 | 5 | TEST(tde, move_only_function) { 6 | MoveOnlyFunction foo = +[] { return 0; }; 7 | ASSERT_EQ(foo(), 0); 8 | ASSERT_TRUE(foo); 9 | foo = {}; 10 | ASSERT_FALSE(foo); 11 | } 12 | 13 | } // namespace tde::details 14 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/tools/repair_wheel.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -xe 3 | WHEEL_FILE=$1 4 | DEST_DIR=$2 5 | 6 | CUDA_SUFFIX=cu$(echo "$CUDA_VERSION" | sed 's#\.##g') 7 | WHEEL_FILENAME=$(basename "${WHEEL_FILE}") 8 | DEST_FILENAME=$(echo "${WHEEL_FILENAME}" | sed -r 's#(torchrec_dynamic_embedding-[0-9]+\.[0-9]+\.[0-9]+)#\1'"+${CUDA_SUFFIX}#g") 9 | mv "${WHEEL_FILE}" "${DEST_DIR}/${DEST_FILENAME}" 10 | -------------------------------------------------------------------------------- /.pyre_configuration: -------------------------------------------------------------------------------- 1 | { 2 | "exclude": [ 3 | 
".*/pyre-check/stubs/.*", 4 | ".*/torchrec/datasets*", 5 | ".*/torchrec/models*", 6 | ".*/torchrec/inference/client.py" 7 | ], 8 | "site_package_search_strategy": "all", 9 | "source_directories": [ 10 | { 11 | "import_root": ".", 12 | "source": "torchrec" 13 | } 14 | ], 15 | "strict": true, 16 | "version": "0.0.101729681899" 17 | } 18 | -------------------------------------------------------------------------------- /examples/prediction/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | 11 | def main() -> None: 12 | """DOC_STRING""" 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/benchmark_zch/models/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | from .dlrmv2 import make_model_dlrmv2 10 | from .dlrmv3 import make_model_dlrmv3 11 | -------------------------------------------------------------------------------- /torchrec/modules/README.rst: -------------------------------------------------------------------------------- 1 | .. fbmeta:: 2 | hide_page_title=false 3 | 4 | TorchRec Module API Documentation 5 | ================================== 6 | 7 | This contains the information about the internal TorchRec Common Modules. 8 | With TOC tree, but it isn't showing up.... 
9 | 10 | TorchRec Common Modules 11 | 12 | 13 | .. tocTree:: 14 | :glob 15 | 16 | 17 | 18 | .. automodule:: torchrec.modules 19 | :members: 20 | -------------------------------------------------------------------------------- /docs/source/_static/img/card-background.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /torchrec/distributed/logging_handlers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import logging 9 | 10 | 11 | __all__: list[str] = [] 12 | 13 | _log_handlers: dict[str, logging.Handler] = { 14 | "default": logging.NullHandler(), 15 | } 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.5.0 4 | hooks: 5 | - id: check-toml 6 | - id: check-yaml 7 | exclude: packaging/.* 8 | - id: end-of-file-fixer 9 | 10 | - repo: https://github.com/omnilib/ufmt 11 | rev: v2.5.1 12 | hooks: 13 | - id: ufmt 14 | additional_dependencies: 15 | - black == 24.2.0 16 | - usort == 1.0.8.post1 17 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/notification.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "tde/details/notification.h" 4 | 5 | namespace tde { 6 | 7 | class Notification : public torch::CustomClassHolder { 8 | public: 9 | void Done() { 10 | return 
notification_.Done(); 11 | } 12 | void Wait() { 13 | return notification_.Wait(); 14 | } 15 | 16 | private: 17 | details::Notification notification_; 18 | }; 19 | 20 | } // namespace tde 21 | -------------------------------------------------------------------------------- /torchrec/distributed/model_tracker/trackers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | """MPZCH Raw ID Tracker 11 | """ 12 | 13 | from torchrec.distributed.model_tracker.trackers.raw_id_tracker import ( # noqa 14 | RawIdTracker, 15 | ) 16 | -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/benchmark_zch/data/Readme.md: -------------------------------------------------------------------------------- 1 | # Datasets for zero collision hash benchmark 2 | 3 | ## Folder structure 4 | - `configs/`: Configs for each dataset, named as `{dataset_name}.yaml` 5 | - `preprocess/`: Scripts that preprocess each dataset so that the returned dataset has the following format: 6 | - batch 7 | - dense_features 8 | - sparse_features 9 | - labels 10 | - `get_dataloader.py`: the entry point to get the dataloader for each dataset 11 | -------------------------------------------------------------------------------- /torchrec/distributed/tests/test_model_parallel_nccl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | 10 | from torchrec.distributed.test_utils.test_model_parallel import ModelParallelBase 11 | 12 | 13 | class ModelParallelTestNccl(ModelParallelBase): 14 | pass 15 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/notification.cpp: -------------------------------------------------------------------------------- 1 | #include "notification.h" 2 | 3 | namespace tde::details { 4 | void Notification::Done() { 5 | { 6 | std::lock_guard guard(mtx_); 7 | set_ = true; 8 | } 9 | cv_.notify_all(); 10 | } 11 | void Notification::Wait() { 12 | std::unique_lock lock(mtx_); 13 | cv_.wait(lock, [this] { return set_; }); 14 | } 15 | 16 | void Notification::Clear() { 17 | set_ = false; 18 | } 19 | } // namespace tde::details 20 | -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/benchmark_zch/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | from .apply_optimizers import ( 10 | apply_dense_optimizers, 11 | apply_sparse_optimizers, 12 | combine_optimizers, 13 | ) 14 | from .make_model import make_model 15 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/bits_op_test.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "tde/details/bits_op.h" 3 | 4 | namespace tde::details { 5 | TEST(TDE, bits_op_Clz) { 6 | ASSERT_EQ(Clz(int32_t(0x7FFFFFFF)), 1); 7 | ASSERT_EQ(Clz(int64_t(0x7FFFFFFFFFFFFFFF)), 1); 8 | ASSERT_EQ(Clz(int8_t(0x7F)), 1); 9 | } 10 | 11 | TEST(TDE, bits_op_Ctz) { 12 | ASSERT_EQ(Ctz(int32_t(0x2)), 1); 13 | ASSERT_EQ(Ctz(int64_t(0xF00)), 8); 14 | ASSERT_EQ(Ctz(int8_t(0x4)), 2); 15 | } 16 | } // namespace tde::details 17 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | #!/usr/bin/env python3 9 | 10 | from .comm import ( 11 | broadcast_ids_to_evict, 12 | broadcast_transform_result, 13 | gather_global_ids, 14 | scatter_cache_ids, 15 | ) 16 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | Docs 2 | ========== 3 | 4 | 5 | ## Building the docs 6 | 7 | To build and preview the docs run the following commands: 8 | 9 | ```bash 10 | cd docs 11 | pip3 install -r requirements.txt 12 | make html 13 | python3 -m http.server 8082 --bind :: 14 | ``` 15 | 16 | Now you should be able to view the docs in your browser at the link provided in your terminal. 17 | 18 | To reload the preview after making changes, rerun: 19 | 20 | ```bash 21 | make html 22 | python3 -m http.server 8082 --bind :: 23 | ``` 24 | -------------------------------------------------------------------------------- /torchrec/distributed/torchrec_logging_handlers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import logging 10 | 11 | from torchrec.distributed.logging_handlers import _log_handlers 12 | 13 | TORCHREC_LOGGER_NAME = "torchrec" 14 | 15 | _log_handlers.update({TORCHREC_LOGGER_NAME: logging.NullHandler()}) 16 | -------------------------------------------------------------------------------- /torchrec/fx/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | """Torchrec Tracer 11 | 12 | Custom FX tracer for torchrec 13 | 14 | See `Torch.FX documentation `_ 15 | """ 16 | 17 | from torchrec.fx.tracer import symbolic_trace, Tracer # noqa 18 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/notification.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace tde::details { 6 | 7 | /** 8 | * Multi-thread notification 9 | */ 10 | class Notification { 11 | public: 12 | void Done(); 13 | void Wait(); 14 | 15 | /** 16 | * Clear the set status. 17 | * 18 | * NOTE: Clear is not thread-safe. 19 | */ 20 | void Clear(); 21 | 22 | private: 23 | bool set_{false}; 24 | std::mutex mtx_; 25 | std::condition_variable cv_; 26 | }; 27 | 28 | } // namespace tde::details 29 | -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/benchmark_zch/data/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | from .criteo_kaggle import get_criteo_kaggle_dataloader 10 | from .kuairand_1k import get_kuairand_1k_dataloader 11 | from .kuairand_27k import get_kuairand_27k_dataloader 12 | from .movielens_1m import get_movielens_1m_dataloader 13 | -------------------------------------------------------------------------------- /torchrec/distributed/global_settings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | PROPOGATE_DEVICE: bool = False 11 | 12 | 13 | def set_propogate_device(val: bool) -> None: 14 | global PROPOGATE_DEVICE 15 | PROPOGATE_DEVICE = val 16 | 17 | 18 | def get_propogate_device() -> bool: 19 | global PROPOGATE_DEVICE 20 | return PROPOGATE_DEVICE 21 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/tools/build_wheels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd $(dirname $0)/../ 3 | export CIBW_BEFORE_BUILD="tools/before_linux_build.sh" 4 | 5 | # Use env CIBW_BUILD="cp*-manylinux_x86_64" tools/build_wheels.sh to build 6 | # all kinds of CPython. 7 | export CIBW_BUILD=${CIBW_BUILD:-"cp39-manylinux_x86_64"} 8 | 9 | export CIBW_MANYLINUX_X86_64_IMAGE=${CIBW_MANYLINUX_X86_64_IMAGE:-"manylinux_2_28"} 10 | 11 | # Do not auditwheels since tde uses torch's shared libraries. 
12 | export CIBW_REPAIR_WHEEL_COMMAND="tools/repair_wheel.sh {wheel} {dest_dir}" 13 | 14 | cibuildwheel --platform linux --archs x86_64 15 | -------------------------------------------------------------------------------- /torchrec/distributed/tests/test_model_parallel_gloo_gpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from torchrec.distributed.test_utils.test_model_parallel import ModelParallelBase 11 | 12 | # GPU tests for Gloo. 13 | 14 | 15 | class ModelParallelTestGloo(ModelParallelBase): 16 | def setUp(self, backend: str = "gloo") -> None: 17 | super().setUp(backend=backend) 18 | -------------------------------------------------------------------------------- /test/cpp/dynamic_embedding/notification_test.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | namespace torchrec { 14 | TEST(TDE, notification) { 15 | Notification notification; 16 | std::thread th([&] { notification.done(); }); 17 | notification.wait(); 18 | th.join(); 19 | } 20 | } // namespace torchrec 21 | -------------------------------------------------------------------------------- /examples/transfer_learning/README.md: -------------------------------------------------------------------------------- 1 | # Transfer Learning example 2 | 3 | This example showcases training a distributed model using TorchRec. 
The embeddings are initialized with pretrained values (assumed to be loaded from storage, such as parquet). These values are large enough that we use the `share_memory_` API to load the tensors from shared memory. 4 | 5 | See [`torch.multiprocessing`](https://pytorch.org/docs/stable/multiprocessing.html) and [best practices](https://pytorch.org/docs/1.6.0/notes/multiprocessing.html?highlight=multiprocessing) for more information on shared memory. 6 | 7 | ## Running 8 | 9 | `python train_from_pretrained_embedding.py` 10 | -------------------------------------------------------------------------------- /test/cpp/dynamic_embedding/redis/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | function(add_redis_test NAME) 8 | add_executable(${NAME} ${ARGN}) 9 | target_link_libraries(${NAME} redis_io tde_cpp_objs gtest gtest_main) 10 | add_test(NAME ${NAME} COMMAND ${NAME}) 11 | endfunction() 12 | 13 | # TODO: Need to start an empty redis-server 14 | # on 127.0.0.1:6379 before running *redis*_test. 15 | add_redis_test(redis_io_test redis_io_test.cpp) 16 | add_redis_test(url_test url_test.cpp) 17 | -------------------------------------------------------------------------------- /torchrec/optim/test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | 10 | from typing import Any 11 | 12 | from torchrec.optim.keyed import KeyedOptimizer 13 | 14 | 15 | class DummyKeyedOptimizer(KeyedOptimizer): 16 | def __init__(self, *args: Any, **kwargs: Any) -> None: 17 | super().__init__(*args, **kwargs) 18 | 19 | # pyre-ignore[2] 20 | def step(self, closure: Any) -> None: 21 | pass # Override NotImplementedError. 22 | -------------------------------------------------------------------------------- /test/cpp/dynamic_embedding/redis/url_test.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | #include 11 | 12 | namespace torchrec::url_parser::rules { 13 | 14 | TEST(TDE, url) { 15 | auto url = parse_url("www.qq.com/?a=b&&c=d"); 16 | ASSERT_EQ(url.host, "www.qq.com"); 17 | ASSERT_TRUE(url.param.has_value()); 18 | ASSERT_EQ("a=b&&c=d", url.param.value()); 19 | } 20 | 21 | } // namespace torchrec::url_parser::rules 22 | -------------------------------------------------------------------------------- /.lintrunner.toml: -------------------------------------------------------------------------------- 1 | [[linter]] 2 | code = 'BLACK' 3 | include_patterns = ['**/*.py'] 4 | command = [ 5 | 'python3', 6 | 'tools/lint/black_linter.py', 7 | '--', 8 | '@{{PATHSFILE}}' 9 | ] 10 | init_command = [ 11 | 'python3', 12 | 'tools/lint/pip_init.py', 13 | '--dry-run={{DRYRUN}}', 14 | 'black==24.2.0', 15 | ] 16 | is_formatter = true 17 | 18 | [[linter]] 19 | code = 'USORT' 20 | include_patterns = ['**/*.py'] 21 | command = [ 22 | 'python3', 23 | 'tools/lint/usort_linter.py', 24 | '--', 25 | '@{{PATHSFILE}}' 26 | ] 27 | init_command = [ 28 | 'python3', 29 | 'tools/lint/pip_init.py', 30 | '--dry-run={{DRYRUN}}', 31 | 'usort==1.0.8', 
32 | ] 33 | is_formatter = true 34 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /torchrec/csrc/dynamic_embedding/details/notification.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "notification.h" 10 | 11 | namespace torchrec { 12 | void Notification::done() { 13 | { 14 | std::lock_guard guard(mu_); 15 | set_ = true; 16 | } 17 | cv_.notify_all(); 18 | } 19 | void Notification::wait() { 20 | std::unique_lock lock(mu_); 21 | cv_.wait(lock, [this] { return set_; }); 22 | } 23 | 24 | void Notification::clear() { 25 | set_ = false; 26 | } 27 | } // namespace torchrec 28 | -------------------------------------------------------------------------------- /torchrec/inference/include/torchrec/inference/JaggedTensor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | namespace torchrec { 17 | 18 | struct JaggedTensor { 19 | at::Tensor lengths; 20 | at::Tensor values; 21 | at::Tensor weights; 22 | }; 23 | 24 | struct KeyedJaggedTensor { 25 | std::vector keys; 26 | at::Tensor lengths; 27 | at::Tensor values; 28 | at::Tensor weights; 29 | }; 30 | 31 | } // namespace torchrec 32 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/id_transformer_variant_test.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "tde/details/id_transformer_variant.h" 3 | 4 | namespace tde::details { 5 | 6 | TEST(TDE, CreateLXUStrategy) { 7 | auto strategy = IDTransformer::LXUStrategy( 8 | {{"min_used_freq_power", 6}, {"type", "mixed_lru_lfu"}}); 9 | } 10 | 11 | TEST(TDE, IDTransformer) { 12 | IDTransformer transformer(1000, nlohmann::json::parse(R"( 13 | { 14 | "lxu_strategy": {"type": "mixed_lru_lfu"}, 15 | "id_transformer": {"type": "naive"} 16 | } 17 | 
)")); 18 | std::vector vec{0, 1, 2}; 19 | std::vector result; 20 | result.resize(vec.size()); 21 | transformer.Transform(vec, result); 22 | } 23 | 24 | } // namespace tde::details 25 | -------------------------------------------------------------------------------- /docs/source/inference-api-reference.rst: -------------------------------------------------------------------------------- 1 | Inference 2 | ---------------------------------- 3 | 4 | TorchRec provides easy-to-use APIs for transforming an authored TorchRec model 5 | into an optimized inference model for distributed inference, via eager module swaps. 6 | 7 | This transforms TorchRec modules like ``EmbeddingBagCollection`` in the model to 8 | a quantized, sharded version that can be compiled using torch.fx and TorchScript 9 | for inference in a C++ environment. 10 | 11 | The intended use is calling ``quantize_inference_model`` on the model followed by 12 | ``shard_quant_model``. 13 | 14 | .. codeblock:: 15 | 16 | .. automodule:: torchrec.inference.modules 17 | 18 | .. autofunction:: quantize_inference_model 19 | .. autofunction:: shard_quant_model 20 | -------------------------------------------------------------------------------- /torchrec/inference/inference_legacy/include/torchrec/inference/JaggedTensor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | namespace torchrec { 17 | 18 | struct JaggedTensor { 19 | at::Tensor lengths; 20 | at::Tensor values; 21 | at::Tensor weights; 22 | }; 23 | 24 | struct KeyedJaggedTensor { 25 | std::vector keys; 26 | at::Tensor lengths; 27 | at::Tensor values; 28 | at::Tensor weights; 29 | }; 30 | 31 | } // namespace torchrec 32 | -------------------------------------------------------------------------------- /test/cpp/dynamic_embedding/bits_op_test.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | #include 11 | 12 | namespace torchrec { 13 | TEST(TDE, bits_op_clz) { 14 | ASSERT_EQ(clz(int32_t(0x7FFFFFFF)), 1); 15 | ASSERT_EQ(clz(int64_t(0x7FFFFFFFFFFFFFFF)), 1); 16 | ASSERT_EQ(clz(int8_t(0x7F)), 1); 17 | } 18 | 19 | TEST(TDE, bits_op_ctz) { 20 | ASSERT_EQ(ctz(int32_t(0x2)), 1); 21 | ASSERT_EQ(ctz(int64_t(0xF00)), 8); 22 | ASSERT_EQ(ctz(int8_t(0x4)), 2); 23 | } 24 | } // namespace torchrec 25 | -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/benchmark_zch/models/configs/dlrmv3.yaml: -------------------------------------------------------------------------------- 1 | hstu_num_heads: 4 # 1 for hstu, 2 for hstu-large 2 | hstu_attn_num_layers: 3 # 2 for hstu, 8 for hstu-large 3 | hstu_embedding_table_dim: 256 4 | hstu_transducer_embedding_dim: 512 5 | hstu_attn_linear_dim: 128 6 | hstu_attn_qk_dim: 128 7 | hstu_input_dropout_ratio: 0.2 8 | hstu_linear_dropout_rate: 0.1 9 | causal_multitask_weights: 0.2 10 | num_embeddings: 100000 11 | embedding_module_attribute_path: "dlrm_hstu._embedding_collection" # the attribute path of embedding 
module after model 12 | managed_collision_module_attribute_path: "module.dlrm_hstu._embedding_collection.mc_embedding_collection._managed_collision_collection._managed_collision_modules" # the attribute path of managed collision module after model 13 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/tensor_list.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace tde { 5 | 6 | class TensorList : public torch::CustomClassHolder { 7 | using Container = std::vector; 8 | 9 | public: 10 | TensorList() = default; 11 | 12 | void push_back(at::Tensor tensor) { 13 | tensors_.push_back(tensor); 14 | } 15 | int64_t size() const { 16 | return tensors_.size(); 17 | } 18 | torch::Tensor& operator[](int64_t index) { 19 | return tensors_[index]; 20 | } 21 | 22 | Container::const_iterator begin() const { 23 | return tensors_.begin(); 24 | } 25 | Container::const_iterator end() const { 26 | return tensors_.end(); 27 | } 28 | 29 | private: 30 | Container tensors_; 31 | }; 32 | 33 | } // namespace tde 34 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/torchrec_dynamic_embedding/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch.nn as nn 4 | from torchrec.distributed.types import ShardingPlan 5 | 6 | 7 | __all__ = [] 8 | 9 | 10 | def _get_sharded_modules_recursive( 11 | module: nn.Module, 12 | path: str, 13 | plan: ShardingPlan, 14 | ) -> Dict[str, nn.Module]: 15 | """ 16 | Get all sharded modules of module from `plan`. 
17 | """ 18 | params_plan = plan.get_plan_for_module(path) 19 | if params_plan: 20 | return {path: (module, params_plan)} 21 | 22 | res = {} 23 | for name, child in module.named_children(): 24 | new_path = f"{path}.{name}" if path else name 25 | res.update(_get_sharded_modules_recursive(child, new_path, plan)) 26 | return res 27 | -------------------------------------------------------------------------------- /benchmarks/cpp/dynamic_embedding/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | function(add_tde_benchmark NAME) 8 | add_executable(${NAME} ${ARGN}) 9 | target_link_libraries(${NAME} tde_cpp_objs benchmark::benchmark_main benchmark::benchmark) 10 | endfunction() 11 | 12 | add_tde_benchmark(naive_id_transformer_benchmark naive_id_transformer_benchmark.cpp) 13 | add_tde_benchmark(random_bits_generator_benchmark random_bits_generator_benchmark.cpp) 14 | add_tde_benchmark(mixed_lfu_lru_strategy_benchmark mixed_lfu_lru_strategy_benchmark.cpp) 15 | add_tde_benchmark(mixed_lfu_lru_strategy_evict_benchmark mixed_lfu_lru_strategy_evict_benchmark.cpp) 16 | -------------------------------------------------------------------------------- /torchrec/csrc/dynamic_embedding/details/types.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | namespace torchrec { 16 | 17 | using lxu_record_t = uint32_t; 18 | 19 | struct record_t { 20 | int64_t global_id; 21 | int64_t cache_id; 22 | lxu_record_t lxu_record; 23 | }; 24 | 25 | using iterator_t = std::function()>; 26 | using update_t = 27 | std::function)>; 28 | using fetch_t = std::function; 29 | 30 | } // namespace torchrec 31 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | push: 5 | branches: 6 | # only run tests on main branch & nightly; release should be triggered manually 7 | - nightly 8 | - main 9 | tags: 10 | # Release candidate tags look like: v1.11.0-rc1 11 | - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ 12 | pull_request: 13 | 14 | jobs: 15 | pre-commit: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - name: Setup Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: 3.12 22 | architecture: x64 23 | packages: | 24 | ufmt==2.5.1 25 | black==24.2.0 26 | usort==1.0.8 27 | - name: Checkout Torchrec 28 | uses: actions/checkout@v4 29 | - name: Run pre-commit 30 | uses: pre-commit/action@v3.0.1 31 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/tests/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | __all__ = [] 8 | 9 | 10 | MEMORY_IO_REGISTERED = False 11 | 12 | 13 | def register_memory_io(): 14 | global MEMORY_IO_REGISTERED 15 | if not MEMORY_IO_REGISTERED: 16 | mem_io_path = os.getenv("TDE_MEMORY_IO_PATH") 17 | if mem_io_path is None: 18 | raise RuntimeError("env TDE_MEMORY_IO_PATH must set for unittest") 19 | 20 | torch.ops.tde.register_io(mem_io_path) 21 | MEMORY_IO_REGISTERED = True 22 | 23 
| 24 | def init_dist(): 25 | if not dist.is_initialized(): 26 | os.environ["RANK"] = "0" 27 | os.environ["WORLD_SIZE"] = "1" 28 | os.environ["MASTER_ADDR"] = "127.0.0.1" 29 | os.environ["MASTER_PORT"] = "13579" 30 | dist.init_process_group("nccl") 31 | -------------------------------------------------------------------------------- /docs/source/datatypes-api-reference.rst: -------------------------------------------------------------------------------- 1 | Data Types 2 | ------------------- 3 | 4 | 5 | TorchRec contains data types for representing embedding, otherwise known as sparse features. 6 | Sparse features are typically indices that are meant to be fed into embedding tables. For a given 7 | batch, the number of embedding lookup indices is variable. Therefore, there is a need for a **jagged** 8 | dimension to represent the variable number of embedding lookup indices for a batch. 9 | 10 | This section covers the classes for the 3 TorchRec data types for representing sparse features: 11 | **JaggedTensor**, **KeyedJaggedTensor**, and **KeyedTensor**. 12 | 13 | .. automodule:: torchrec.sparse.jagged_tensor 14 | 15 | .. autoclass:: JaggedTensor 16 | :members: 17 | 18 | .. autoclass:: KeyedJaggedTensor 19 | :members: 20 | 21 | .. autoclass:: KeyedTensor 22 | :members: 23 | -------------------------------------------------------------------------------- /test/cpp/dynamic_embedding/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | function(add_tde_test NAME) 8 | add_executable(${NAME} ${ARGN}) 9 | target_link_libraries(${NAME} tde_cpp_objs gtest gtest_main) 10 | add_test(NAME ${NAME} COMMAND ${NAME}) 11 | endfunction() 12 | 13 | add_tde_test(bits_op_test bits_op_test.cpp) 14 | add_tde_test(naive_id_transformer_test naive_id_transformer_test.cpp) 15 | add_tde_test(random_bits_generator_test random_bits_generator_test.cpp) 16 | add_tde_test(mixed_lfu_lru_strategy_test mixed_lfu_lru_strategy_test.cpp) 17 | add_tde_test(notification_test notification_test.cpp) 18 | 19 | if (BUILD_REDIS_IO) 20 | add_subdirectory(redis) 21 | endif() 22 | -------------------------------------------------------------------------------- /torchrec/distributed/tests/test_model_parallel_gloo_gpu_single_rank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from torchrec.distributed.test_utils.test_model_parallel_base import ( 11 | ModelParallelSparseOnlyBase, 12 | ModelParallelStateDictBase, 13 | ) 14 | 15 | # Single rank GPU tests for Gloo. 
16 | 17 | 18 | class ModelParallelStateDictTestGloo(ModelParallelStateDictBase): 19 | def setUp(self, backend: str = "gloo") -> None: 20 | super().setUp(backend=backend) 21 | 22 | 23 | class ModelParallelSparseOnlyTestGloo(ModelParallelSparseOnlyBase): 24 | def setUp(self, backend: str = "gloo") -> None: 25 | super().setUp(backend=backend) 26 | -------------------------------------------------------------------------------- /.github/workflows/pyre.yml: -------------------------------------------------------------------------------- 1 | name: Pyre Check 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | jobs: 9 | pyre-check: 10 | runs-on: ubuntu-22.04 11 | defaults: 12 | run: 13 | shell: bash -el {0} 14 | steps: 15 | - uses: conda-incubator/setup-miniconda@v2 16 | with: 17 | python-version: 3.9 18 | - name: Checkout Torchrec 19 | uses: actions/checkout@v4 20 | - name: Install dependencies 21 | run: > 22 | pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu && 23 | pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu && 24 | pip install -r requirements.txt && 25 | pip install pyre-check-nightly==$(cat .pyre_configuration | grep version | awk '{print $2}' | sed 's/\"//g') 26 | - name: Pyre check 27 | run: pyre check 28 | -------------------------------------------------------------------------------- /.github/workflows/validate-nightly-binaries.yml: -------------------------------------------------------------------------------- 1 | # Scheduled validation of the nightly binaries 2 | name: validate-nightly-binaries 3 | 4 | on: 5 | schedule: 6 | # At 5:30 pm UTC (7:30 am PDT) 7 | - cron: "30 17 * * *" 8 | # Have the ability to trigger this job manually through the API 9 | workflow_dispatch: 10 | push: 11 | branches: 12 | - main 13 | paths: 14 | - '.github/workflows/validate-nightly-binaries.yml' 15 | - '.github/workflows/validate-binaries.yml' 16 | - '.github/scripts/validate-binaries.sh' 17 | 
pull_request: 18 | paths: 19 | - '.github/workflows/validate-nightly-binaries.yml' 20 | - '.github/workflows/validate-binaries.yml' 21 | - '.github/scripts/validate_binaries.sh' 22 | - '.github/scripts/filter.py' 23 | jobs: 24 | nightly: 25 | uses: ./.github/workflows/validate-binaries.yml 26 | with: 27 | channel: nightly 28 | -------------------------------------------------------------------------------- /torchrec/csrc/dynamic_embedding/details/notification.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include 12 | #include 13 | 14 | namespace torchrec { 15 | 16 | /** 17 | * Multi-thread notification 18 | */ 19 | class Notification : public torch::CustomClassHolder { 20 | public: 21 | Notification() = default; 22 | 23 | void done(); 24 | void wait(); 25 | 26 | /** 27 | * Clear the set status. 28 | * 29 | * NOTE: Clear is not thread-safe. 
30 | */ 31 | void clear(); 32 | 33 | private: 34 | bool set_{false}; 35 | std::mutex mu_; 36 | std::condition_variable cv_; 37 | }; 38 | 39 | } // namespace torchrec 40 | -------------------------------------------------------------------------------- /docs/source/modules-api-reference.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ---------------------------------- 3 | 4 | Standard TorchRec modules represent collections of embedding tables: 5 | 6 | * ``EmbeddingBagCollection`` is a collection of ``torch.nn.EmbeddingBag`` 7 | * ``EmbeddingCollection`` is a collection of ``torch.nn.Embedding`` 8 | 9 | These modules are constructed through standardized config classes: 10 | 11 | * ``EmbeddingBagConfig`` for ``EmbeddingBagCollection`` 12 | * ``EmbeddingConfig`` for ``EmbeddingCollection`` 13 | 14 | .. automodule:: torchrec.modules.embedding_configs 15 | 16 | .. autoclass:: EmbeddingBagConfig 17 | :show-inheritance: 18 | 19 | .. autoclass:: EmbeddingConfig 20 | :show-inheritance: 21 | 22 | .. autoclass:: BaseEmbeddingConfig 23 | 24 | .. automodule:: torchrec.modules.embedding_modules 25 | 26 | .. autoclass:: EmbeddingBagCollection 27 | :members: 28 | 29 | .. autoclass:: EmbeddingCollection 30 | :members: 31 | -------------------------------------------------------------------------------- /torchrec/modules/object_pool.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | 10 | import abc 11 | from typing import Generic, TypeVar 12 | 13 | import torch 14 | 15 | T = TypeVar("T") 16 | 17 | 18 | class ObjectPool(abc.ABC, torch.nn.Module, Generic[T]): 19 | """ 20 | Interface for TensorPool and KeyedJaggedTensorPool 21 | 22 | Defines methods for lookup, update and obtaining pool size 23 | """ 24 | 25 | @abc.abstractmethod 26 | def lookup(self, ids: torch.Tensor) -> T: 27 | pass 28 | 29 | @abc.abstractmethod 30 | def update(self, ids: torch.Tensor, values: T) -> None: 31 | pass 32 | 33 | @abc.abstractproperty 34 | def pool_size(self) -> int: 35 | pass 36 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/bits_op.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace tde::details { 6 | 7 | namespace bits_impl { 8 | template 9 | struct Clz { 10 | int operator()(T v) const; 11 | }; 12 | 13 | template 14 | struct Ctz { 15 | int operator()(T v) const; 16 | }; 17 | } // namespace bits_impl 18 | 19 | /** 20 | * Returns the number of leading 0-bits in t, starting at the most significant 21 | * bit position. If t is 0, the result is undefined. 22 | */ 23 | template 24 | inline int Clz(T t) { 25 | bits_impl::Clz clz; 26 | return clz(t); 27 | } 28 | 29 | /** 30 | * Returns the number of trailing 0-bits in t, starting at the least significant 31 | * bit position. If t is 0, the result is undefined. 32 | */ 33 | template 34 | inline int Ctz(T t) { 35 | bits_impl::Ctz ctz; 36 | return ctz(t); 37 | } 38 | 39 | } // namespace tde::details 40 | -------------------------------------------------------------------------------- /examples/retrieval/tests/test_two_tower_retrieval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | 10 | import unittest 11 | 12 | import torch 13 | 14 | from torchrec.test_utils import skip_if_asan 15 | 16 | # @manual=//torchrec/github/examples/retrieval:two_tower_retrieval_lib 17 | from ..two_tower_retrieval import infer 18 | 19 | 20 | class InferTest(unittest.TestCase): 21 | @skip_if_asan 22 | # pyre-ignore[56] 23 | @unittest.skipIf( 24 | torch.cuda.device_count() <= 1, 25 | "Not enough GPUs, this test requires at least two GPUs", 26 | ) 27 | def test_infer_function(self) -> None: 28 | infer( 29 | embedding_dim=16, 30 | layer_sizes=[16], 31 | world_size=2, 32 | ) 33 | -------------------------------------------------------------------------------- /tools/lint/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | 9 | import os 10 | from enum import Enum 11 | from typing import NamedTuple, Optional 12 | 13 | IS_WINDOWS: bool = os.name == "nt" 14 | 15 | 16 | class LintSeverity(str, Enum): 17 | ERROR = "error" 18 | WARNING = "warning" 19 | ADVICE = "advice" 20 | DISABLED = "disabled" 21 | 22 | 23 | class LintMessage(NamedTuple): 24 | path: Optional[str] 25 | line: Optional[int] 26 | char: Optional[int] 27 | code: str 28 | severity: LintSeverity 29 | name: str 30 | original: Optional[str] 31 | replacement: Optional[str] 32 | description: Optional[str] 33 | 34 | 35 | def as_posix(name: str) -> str: 36 | return name.replace("\\", "/") if IS_WINDOWS else name 37 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/redis_io.cpp: -------------------------------------------------------------------------------- 1 | #include "redis_io.h" 2 | #include "tde/details/io_registry.h" 3 | #include "tde/details/redis_io_v1.h" 4 | 5 | namespace tde::details { 6 | 7 | void RegisterRedisIO() { 8 | auto& reg = IORegistry::Instance(); 9 | 10 | { 11 | IOProvider provider{}; 12 | provider.type_ = "redis"; 13 | provider.Initialize = +[](const char* cfg) -> void* { 14 | auto opt = redis_v1::Option::Parse(cfg); 15 | return new redis_v1::RedisV1(opt); 16 | }; 17 | provider.Finalize = 18 | +[](void* inst) { delete reinterpret_cast(inst); }; 19 | provider.Pull = +[](void* inst, IOPullParameter param) { 20 | reinterpret_cast(inst)->Pull(param); 21 | }; 22 | provider.Push = +[](void* inst, IOPushParameter param) { 23 | reinterpret_cast(inst)->Push(param); 24 | }; 25 | reg.Register(provider); 26 | } 27 | } 28 | } // namespace tde::details 29 | -------------------------------------------------------------------------------- /torchrec/modules/tests/test_activation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import unittest 11 | 12 | import torch 13 | from torchrec.fx import Tracer 14 | from torchrec.modules.activation import SwishLayerNorm 15 | 16 | 17 | class TestActivation(unittest.TestCase): 18 | def test_swish_takes_float(self) -> None: 19 | m = SwishLayerNorm([3, 4]) 20 | input = torch.randn(2, 3, 4) 21 | output = m(input) 22 | norm = torch.nn.LayerNorm([3, 4]) 23 | ref_output = input * torch.sigmoid(norm(input)) 24 | self.assertTrue(torch.allclose(output, ref_output)) 25 | 26 | def test_fx_script_swish(self) -> None: 27 | m = SwishLayerNorm(10) 28 | 29 | gm = torch.fx.GraphModule(m, Tracer().trace(m)) 30 | torch.jit.script(gm) 31 | -------------------------------------------------------------------------------- /torchrec/datasets/scripts/nvt/nvt_preproc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | set -e 9 | 10 | INPUT_PATH="$1" 11 | BASE_OUTPUT_PATH="$2" 12 | BATCH_SIZE="$3" 13 | TEMP_PATH=""$BASE_OUTPUT_PATH"temp/" 14 | SRC_DIR=""$BASE_OUTPUT_PATH"criteo_preproc/" 15 | BINARY_OUTPUT_PATH=""$BASE_OUTPUT_PATH"criteo_binary/" 16 | FINAL_OUTPUT_PATH=""$BINARY_OUTPUT_PATH"split" 17 | 18 | python convert_tsv_to_parquet.py -i "$INPUT_PATH" -o "$BASE_OUTPUT_PATH" 19 | python process_criteo_parquet.py -b "$BASE_OUTPUT_PATH" -s 20 | python convert_parquet_to_binary.py --src_dir "$SRC_DIR" \ 21 | --intermediate_dir "$TEMP_PATH" \ 22 | --dst_dir "$BINARY_OUTPUT_PATH" 23 | python split_binary_dataset.py --input_path "$BINARY_OUTPUT_PATH" --output_path "$FINAL_OUTPUT_PATH" --batch_size "$BATCH_SIZE" 24 | -------------------------------------------------------------------------------- /torchrec/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | """Torchrec Models 11 | 12 | Torchrec provides the architecture for two popular recsys models; 13 | `DeepFM `_ and `DLRM (Deep Learning Recommendation Model) 14 | `_. 15 | 16 | Along with the overall model, the individual architectures of each layer are also 17 | provided (e.g. `SparseArch`, `DenseArch`, `InteractionArch`, and `OverArch`). 18 | 19 | Examples can be found within each model. 
20 | 21 | The following notation is used throughout the documentation for the models: 22 | 23 | * F: number of sparse features 24 | * D: embedding_dimension of sparse features 25 | * B: batch size 26 | * num_features: number of dense features 27 | """ 28 | 29 | from torchrec.models import deepfm, dlrm # noqa 30 | -------------------------------------------------------------------------------- /torchrec/csrc/dynamic_embedding/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | add_library(tde_cpp_objs 8 | OBJECT 9 | bind.cpp 10 | id_transformer_wrapper.cpp 11 | ps.cpp 12 | details/clz_impl.cpp 13 | details/ctz_impl.cpp 14 | details/random_bits_generator.cpp 15 | details/io_registry.cpp 16 | details/io.cpp 17 | details/notification.cpp) 18 | 19 | if (BUILD_REDIS_IO) 20 | add_subdirectory(details/redis) 21 | endif() 22 | 23 | target_include_directories(tde_cpp_objs PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../) 24 | target_include_directories(tde_cpp_objs PUBLIC ${TORCH_INCLUDE_DIRS}) 25 | target_link_libraries(tde_cpp_objs PUBLIC ${TORCH_LIBRARIES}) 26 | target_compile_options(tde_cpp_objs PUBLIC -fPIC) 27 | target_link_libraries(tde_cpp_objs PUBLIC ${CMAKE_DL_LIBS}) 28 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/random_bits_generator_benchmark.cpp: -------------------------------------------------------------------------------- 1 | #include "benchmark/benchmark.h" 2 | #include "tde/details/random_bits_generator.h" 3 | 4 | namespace tde::details { 5 | 6 | void BMRandomBitsGenerator(benchmark::State& state) { 7 | auto n = state.range(0); 8 | auto n_bits_limit = state.range(1); 9 | BitScanner scanner(n); 10 | 
std::mt19937_64 engine((std::random_device())()); 11 | std::uniform_int_distribution dist(1, n_bits_limit); 12 | uint16_t n_bits = dist(engine); 13 | for (auto _ : state) { 14 | if (n_bits != 0) { 15 | scanner.ResetArray([&](tcb::span span) { 16 | for (auto& v : span) { 17 | v = engine(); 18 | } 19 | }); 20 | } else { 21 | n_bits = dist(engine); 22 | } 23 | benchmark::DoNotOptimize(scanner.IsNextNBitsAllZero(n_bits)); 24 | } 25 | } 26 | 27 | BENCHMARK(BMRandomBitsGenerator) 28 | ->ArgNames({"n", "limit"}) 29 | ->Args({1024, 32}) 30 | ->Unit(benchmark::kMillisecond) 31 | ->Iterations(1024 * 1024); 32 | 33 | } // namespace tde::details 34 | -------------------------------------------------------------------------------- /torchrec/inference/include/torchrec/inference/Validation.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include "torchrec/inference/Types.h" 12 | 13 | namespace torchrec { 14 | 15 | // Returns whether sparse features (KeyedJaggedTensor) are valid. 16 | // Currently validates: 17 | // 1. Whether sum(lengths) == size(values) 18 | // 2. Whether there are negative values in lengths 19 | // 3. If weights is present, whether sum(lengths) == size(weights) 20 | bool validateSparseFeatures( 21 | at::Tensor& values, 22 | at::Tensor& lengths, 23 | std::optional maybeWeights = std::nullopt); 24 | 25 | // Returns whether dense features are valid. 26 | // Currently validates: 27 | // 1. 
Whether the size of values is divisible by batch size (request level) 28 | bool validateDenseFeatures(at::Tensor& values, size_t batchSize); 29 | 30 | } // namespace torchrec 31 | -------------------------------------------------------------------------------- /torchrec/inference/inference_legacy/include/torchrec/inference/Validation.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include "torchrec/inference/Types.h" 12 | 13 | namespace torchrec { 14 | 15 | // Returns whether sparse features (KeyedJaggedTensor) are valid. 16 | // Currently validates: 17 | // 1. Whether sum(lengths) == size(values) 18 | // 2. Whether there are negative values in lengths 19 | // 3. If weights is present, whether sum(lengths) == size(weights) 20 | bool validateSparseFeatures( 21 | at::Tensor& values, 22 | at::Tensor& lengths, 23 | std::optional maybeWeights = std::nullopt); 24 | 25 | // Returns whether dense features are valid. 26 | // Currently validates: 27 | // 1. 
Whether the size of values is divisible by batch size (request level) 28 | bool validateDenseFeatures(at::Tensor& values, size_t batchSize); 29 | 30 | } // namespace torchrec 31 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/id_transformer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include "tde/details/id_transformer_variant.h" 5 | #include "tde/tensor_list.h" 6 | 7 | namespace tde { 8 | 9 | struct TransformResult : public torch::CustomClassHolder { 10 | TransformResult(bool success, torch::Tensor ids_to_fetch) 11 | : success_(success), ids_to_fetch_(ids_to_fetch) {} 12 | bool success_; 13 | torch::Tensor ids_to_fetch_; 14 | }; 15 | 16 | class IDTransformer : public torch::CustomClassHolder { 17 | public: 18 | IDTransformer(int64_t num_embeddings, nlohmann::json json); 19 | c10::intrusive_ptr Transform( 20 | c10::intrusive_ptr global_ids, 21 | c10::intrusive_ptr cache_ids, 22 | int64_t time); 23 | 24 | torch::Tensor Evict(int64_t num_to_evict); 25 | torch::Tensor Save(); 26 | 27 | private: 28 | std::mutex mu_; 29 | details::IDTransformer transformer_; 30 | std::vector ids_to_fetch_; 31 | int64_t time_; 32 | int64_t last_save_time_; 33 | }; 34 | 35 | } // namespace tde 36 | -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/benchmark_zch/models/configs/dlrmv2.yaml: -------------------------------------------------------------------------------- 1 | dense_arch_layer_sizes: 2 | - 512 3 | - 256 4 | - 64 5 | over_arch_layer_sizes: 6 | - 512 7 | - 512 8 | - 256 9 | - 1 10 | embedding_dim: 64 11 | num_embeddings_per_feature: 12 | cat_0: 40000000 13 | cat_1: 39060 14 | cat_2: 17295 15 | cat_3: 7424 16 | cat_4: 20265 17 | cat_5: 3 18 | cat_6: 7122 19 | cat_7: 1543 20 | cat_8: 63 21 | cat_9: 40000000 22 | cat_10: 3067956 23 | cat_11: 405282 24 | cat_12: 10 
25 | cat_13: 2209 26 | cat_14: 11938 27 | cat_15: 155 28 | cat_16: 4 29 | cat_17: 976 30 | cat_18: 14 31 | cat_19: 40000000 32 | cat_20: 40000000 33 | cat_21: 40000000 34 | cat_22: 590152 35 | cat_23: 12973 36 | cat_24: 108 37 | cat_25: 36 38 | embedding_module_attribute_path: "dlrm.sparse_arch.embedding_bag_collection" # the attribute path after model 39 | managed_collision_module_attribute_path: "module.dlrm.sparse_arch.embedding_bag_collection.mc_embedding_bag_collection._managed_collision_collection._managed_collision_modules" # the attribute path of managed collision module after model 40 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/mixed_lfu_lru_strategy.cpp: -------------------------------------------------------------------------------- 1 | #include "mixed_lfu_lru_strategy.h" 2 | #include 3 | #include "c10/macros/Macros.h" 4 | 5 | namespace tde::details { 6 | MixedLFULRUStrategy::MixedLFULRUStrategy(uint16_t min_used_freq_power) 7 | : min_lfu_power_(min_used_freq_power), time_(new std::atomic()) {} 8 | 9 | void MixedLFULRUStrategy::UpdateTime(uint32_t time) { 10 | time_->store(time); 11 | } 12 | 13 | MixedLFULRUStrategy::lxu_record_t MixedLFULRUStrategy::Update( 14 | int64_t global_id, 15 | int64_t cache_id, 16 | std::optional val) { 17 | Record r{}; 18 | r.time_ = time_->load(); 19 | 20 | if (C10_UNLIKELY(!val.has_value())) { 21 | r.freq_power_ = min_lfu_power_; 22 | } else { 23 | auto freq_power = reinterpret_cast(&val.value())->freq_power_; 24 | bool should_carry = generator_.IsNextNBitsAllZero(freq_power); 25 | if (should_carry) { 26 | ++freq_power; 27 | } 28 | r.freq_power_ = freq_power; 29 | } 30 | return *reinterpret_cast(&r); 31 | } 32 | 33 | } // namespace tde::details 34 | -------------------------------------------------------------------------------- /torchrec/distributed/tests/test_model_parallel_gloo.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from torchrec.distributed.test_utils.test_model_parallel import ModelParallelBase 11 | from torchrec.distributed.test_utils.test_model_parallel_base import ( 12 | ModelParallelSparseOnlyBase, 13 | ModelParallelStateDictBase, 14 | ) 15 | 16 | # CPU tests for Gloo. 17 | 18 | 19 | class ModelParallelTestGloo(ModelParallelBase): 20 | def setUp(self, backend: str = "gloo") -> None: 21 | super().setUp(backend=backend) 22 | 23 | 24 | class ModelParallelStateDictTestGloo(ModelParallelStateDictBase): 25 | def setUp(self, backend: str = "gloo") -> None: 26 | super().setUp(backend=backend) 27 | 28 | 29 | class ModelParallelSparseOnlyTestGloo(ModelParallelSparseOnlyBase): 30 | def setUp(self, backend: str = "gloo") -> None: 31 | super().setUp(backend=backend) 32 | -------------------------------------------------------------------------------- /torchrec/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
class TensorList:
    """Python wrapper around the ``torch.classes.tde.TensorList`` custom class.

    The wrapped native list is exposed as ``self.tensor_list``; ``len()`` and
    indexing are delegated to it.
    """

    def __init__(self, tensors: List[torch.Tensor]):
        """Create the native list and append ``.data`` of every given tensor."""
        self.tensor_list = native_list = torch.classes.tde.TensorList()
        for item in tensors:
            # Appending ``.data`` (rather than the tensor itself) allows
            # in-place ops on the stored tensor during autograd.
            # https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/2
            native_list.append(item.data)

    def __len__(self):
        """Return the number of entries held by the native list."""
        wrapped = self.tensor_list
        return len(wrapped)

    def __getitem__(self, i):
        """Return entry ``i`` of the native list."""
        wrapped = self.tensor_list
        return wrapped[i]
# Prepares a Linux build container: installs the CUDA toolkit and cuDNN from
# NVIDIA's yum repository, then the Python build tooling and a nightly torch
# wheel built for the same CUDA version.

# -x traces every command, -e aborts on the first failing one.
set -xe

distro=rhel7
arch=x86_64
# CUDA_VERSION may be supplied by the caller; defaults to 11.8.
CUDA_VERSION="${CUDA_VERSION:-11.8}"

# Split "MAJOR.MINOR" into its two components.
CUDA_MAJOR_VERSION=$(echo "${CUDA_VERSION}" | tr '.' ' ' | awk '{print $1}')
CUDA_MINOR_VERSION=$(echo "${CUDA_VERSION}" | tr '.' ' ' | awk '{print $2}')

# yum-utils provides yum-config-manager, used to register NVIDIA's CUDA repo.
yum install -y yum-utils
yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-$distro.repo
yum install -y \
    cuda-toolkit-"${CUDA_MAJOR_VERSION}"-"${CUDA_MINOR_VERSION}" \
    libcudnn8-devel
# Relative link target: /usr/local/cuda -> cuda-MAJOR.MINOR (a sibling under
# /usr/local).
# NOTE(review): without -f/-n this fails (and aborts under -e) if the link
# already exists as a file, or nests inside it if it is a directory -- confirm
# the image never ships /usr/local/cuda.
ln -s cuda-"${CUDA_MAJOR_VERSION}"."${CUDA_MINOR_VERSION}" /usr/local/cuda

pipx install cmake
pipx install ninja
python -m pip install scikit-build
# The wheel index is selected by CUDA version, e.g. .../nightly/cu118.
python -m pip install --pre torch --extra-index-url \
    https://download.pytorch.org/whl/nightly/cu"${CUDA_MAJOR_VERSION}""${CUDA_MINOR_VERSION}"
fbcode//torchrec/distributed/benchmark:benchmark_comms -- \ 20 | a2a_single --name=a2a_sync_base-$(hg whereami | cut -c 1-10) --memory_snapshot=true 21 | ``` 22 | - oss: 23 | ``` 24 | python -m torchrec.distributed.benchmark.benchmark_comms \ 25 | a2a_single --name=a2a_sync_base-$(git rev-parse --short HEAD || echo $USER) --memory_snapshot=true 26 | ``` 27 | -------------------------------------------------------------------------------- /torchrec/inference/include/torchrec/inference/TestUtils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | #include "torchrec/inference/JaggedTensor.h" 17 | #include "torchrec/inference/Types.h" 18 | 19 | namespace torchrec { 20 | 21 | std::shared_ptr createRequest(at::Tensor denseTensor); 22 | 23 | std::shared_ptr 24 | createRequest(size_t batchSize, size_t numFeatures, const JaggedTensor& jagged); 25 | 26 | std::shared_ptr 27 | createRequest(size_t batchSize, size_t numFeatures, at::Tensor embedding); 28 | 29 | JaggedTensor createJaggedTensor(const std::vector>& input); 30 | 31 | c10::List createIValueList( 32 | const std::vector>& input); 33 | 34 | at::Tensor createEmbeddingTensor( 35 | const std::vector>& input); 36 | 37 | } // namespace torchrec 38 | -------------------------------------------------------------------------------- /torchrec/inference/inference_legacy/include/torchrec/inference/TestUtils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | #include "torchrec/inference/JaggedTensor.h" 17 | #include "torchrec/inference/Types.h" 18 | 19 | namespace torchrec { 20 | 21 | std::shared_ptr createRequest(at::Tensor denseTensor); 22 | 23 | std::shared_ptr 24 | createRequest(size_t batchSize, size_t numFeatures, const JaggedTensor& jagged); 25 | 26 | std::shared_ptr 27 | createRequest(size_t batchSize, size_t numFeatures, at::Tensor embedding); 28 | 29 | JaggedTensor createJaggedTensor(const std::vector>& input); 30 | 31 | c10::List createIValueList( 32 | const std::vector>& input); 33 | 34 | at::Tensor createEmbeddingTensor( 35 | const std::vector>& input); 36 | 37 | } // namespace torchrec 38 | -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/yaml/sparse_data_dist_ssd.yml: -------------------------------------------------------------------------------- 1 | # this is a very basic sparse data dist config 2 | # runs on 2 ranks, showing traces with reasonable workloads 3 | RunOptions: 4 | world_size: 2 5 | num_batches: 5 6 | num_benchmarks: 2 7 | sharding_type: table_wise 8 | profile_dir: "." 
"""Torchrec Planner

The planner provides the specifications necessary for a module to be sharded,
considering the possible options to build an optimized plan.

Features include:
    - generating all possible sharding options.
    - estimating perf and storage for every shard.
    - estimating peak memory usage to eliminate sharding plans that might OOM.
    - customizability for parameter constraints, partitioning, proposers, or performance
      modeling.
    - automatically building and selecting an optimized sharding plan.
"""

from torchrec.distributed.planner.planners import (  # noqa
    EmbeddingPlannerBase,
    EmbeddingShardingPlanner,
)
from torchrec.distributed.planner.types import ParameterConstraints, Topology  # noqa
from torchrec.distributed.planner.utils import bytes_to_gb, sharder_name  # noqa
world_size: 2 5 | num_batches: 5 6 | num_benchmarks: 2 7 | sharding_type: table_wise 8 | profile_dir: "." 9 | name: "sparse_data_dist_base" 10 | # export_stacks: True # enable this to export stack traces 11 | PipelineConfig: 12 | pipeline: "sparse" 13 | EmbeddingTablesConfig: 14 | num_unweighted_features: 100 15 | num_weighted_features: 100 16 | embedding_feature_dim: 256 17 | additional_tables: 18 | - - name: FP16_table 19 | embedding_dim: 512 20 | num_embeddings: 100_000 21 | feature_names: ["additional_0_0"] 22 | data_type: FP16 23 | - name: large_table 24 | embedding_dim: 256 25 | num_embeddings: 1_000_000 26 | feature_names: ["additional_0_1"] 27 | - [] 28 | - - name: skipped_table 29 | embedding_dim: 128 30 | num_embeddings: 100_000 31 | feature_names: ["additional_2_1"] 32 | PlannerConfig: 33 | additional_constraints: 34 | large_table: 35 | compute_kernels: [fused_uvm_caching] 36 | sharding_types: [row_wise] 37 | -------------------------------------------------------------------------------- /torchrec/inference/inference_legacy/include/torchrec/inference/ShardedTensor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
namespace torchrec {

// Describes one shard by its offsets and lengths.
// NOTE(review): presumably one entry per tensor dimension -- confirm.
struct ShardMetadata {
  std::vector shard_offsets;
  std::vector shard_lengths;

  // Two shards are equal when both their offsets and their lengths match.
  bool operator==(const ShardMetadata& other) const {
    return shard_offsets == other.shard_offsets &&
        shard_lengths == other.shard_lengths;
  }
};

// A shard's metadata paired with the tensor holding that shard's data.
struct Shard {
  ShardMetadata metadata;
  at::Tensor tensor;
};

// Metadata for the shards of a sharded tensor.
struct ShardedTensorMetadata {
  std::vector shards_metadata;
};

// A sharded tensor: its sizes, the shards held locally, and shard metadata.
// NOTE(review): `sizes` presumably is the full (unsharded) shape -- confirm.
struct ShardedTensor {
  std::vector sizes;
  std::vector local_shards;
  ShardedTensorMetadata metadata;
};

// One replica of a sharded tensor, with its id and the total replica count.
struct ReplicatedTensor {
  ShardedTensor local_replica;
  int64_t local_replica_id;
  int64_t replica_count;
};

} // namespace torchrec
namespace torchrec {
// Completes a single request with an error.
//
// Wraps msg in an exception, stores it on a freshly created response, and
// fulfills the promise with that response. The promise is resolved with a
// value (setValue); the error travels in the response's `exception` field.
template
void handleRequestException(
    folly::Promise>& promise,
    const std::string& msg) {
  auto ex = folly::make_exception_wrapper(msg);
  auto response = std::make_unique();
  response->exception = std::move(ex);
  promise.setValue(std::move(response));
}

// Completes every request in a batch with the same error message by calling
// handleRequestException on each context's promise.
template
void handleBatchException(
    std::vector& contexts,
    const std::string& msg) {
  for (auto& context : contexts) {
    handleRequestException(context.promise, msg);
  }
}

} // namespace torchrec
10 | name: "sparse_data_dist_base" 11 | # export_stacks: True # enable this to export stack traces 12 | PipelineConfig: 13 | pipeline: "sparse" 14 | ModelInputConfig: 15 | feature_pooling_avg: 30 16 | EmbeddingTablesConfig: 17 | num_unweighted_features: 90 18 | num_weighted_features: 80 19 | embedding_feature_dim: 256 20 | additional_tables: 21 | - - name: FP16_table 22 | embedding_dim: 512 23 | num_embeddings: 100_000 24 | feature_names: ["additional_0_0"] 25 | data_type: FP16 26 | - name: large_table 27 | embedding_dim: 2048 28 | num_embeddings: 1_000_000 29 | feature_names: ["additional_0_1"] 30 | - [] 31 | - - name: skipped_table 32 | embedding_dim: 128 33 | num_embeddings: 100_000 34 | feature_names: ["additional_2_1"] 35 | PlannerConfig: 36 | additional_constraints: 37 | large_table: 38 | sharding_types: [column_wise] 39 | -------------------------------------------------------------------------------- /torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
# Validates published torchrec binaries by delegating to the shared
# pytorch/test-infra domain-library validation workflow.
name: Validate binaries

on:
  # Entry point for other workflows.
  workflow_call:
    inputs:
      channel:
        description: "Channel to use (nightly, release)"
        required: false
        type: string
        default: release
      ref:
        description: 'Reference to checkout, defaults to empty'
        default: ""
        required: false
        type: string
  # Entry point for manual runs from the Actions UI.
  workflow_dispatch:
    inputs:
      channel:
        # Keep this description in sync with `options` below; it previously
        # listed "pypi", which cannot be selected.
        description: "Channel to use (release, nightly, test)"
        required: true
        type: choice
        options:
          - release
          - nightly
          - test
      ref:
        description: 'Reference to checkout, defaults to empty'
        default: ""
        required: false
        type: string

jobs:
  validate-binaries:
    uses: pytorch/test-infra/.github/workflows/validate-domain-library.yml@main
    # NOTE(review): the `ref` input is declared above but not forwarded here --
    # confirm whether the reusable workflow expects it.
    with:
      package_type: "wheel"
      os: "linux"
      channel: ${{ inputs.channel }}
      repository: "meta-pytorch/torchrec"
      smoke_test: "source ./.github/scripts/validate_binaries.sh"
      with_cuda: enable
      with_rocm: false
-------------------------------------------------------------------------------- /torchrec/csrc/dynamic_embedding/details/lxu_strategy.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include 12 | 13 | namespace torchrec { 14 | 15 | class LXUStrategy { 16 | public: 17 | LXUStrategy() = default; 18 | LXUStrategy(const LXUStrategy&) = delete; 19 | LXUStrategy(LXUStrategy&& o) noexcept = default; 20 | 21 | virtual void update_time(uint32_t time) = 0; 22 | virtual int64_t time(lxu_record_t record) = 0; 23 | 24 | virtual lxu_record_t update( 25 | int64_t global_id, 26 | int64_t cache_id, 27 | std::optional val) = 0; 28 | 29 | /** 30 | * Analysis all ids and returns the num_elems that are most need to evict. 31 | * @param iterator Returns each global_id to ExtValue pair. Returns nullopt 32 | * when at ends. 33 | * @param num_to_evict 34 | * @return 35 | */ 36 | virtual std::vector evict( 37 | iterator_t iterator, 38 | uint64_t num_to_evict) = 0; 39 | }; 40 | 41 | } // namespace torchrec 42 | -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/yaml/base_pipeline_light.yml: -------------------------------------------------------------------------------- 1 | # this is a very basic sparse data dist config 2 | # runs on 2 ranks, showing traces with reasonable workloads 3 | RunOptions: 4 | world_size: 2 5 | num_batches: 10 6 | num_benchmarks: 1 7 | num_profiles: 1 8 | sharding_type: table_wise 9 | profile_dir: "." 
10 | name: "base_pipeline_light" 11 | # export_stacks: True # enable this to export stack traces 12 | loglevel: "info" 13 | PipelineConfig: 14 | pipeline: "base" 15 | ModelInputConfig: 16 | feature_pooling_avg: 10 17 | EmbeddingTablesConfig: 18 | num_unweighted_features: 20 19 | num_weighted_features: 20 20 | embedding_feature_dim: 256 21 | additional_tables: 22 | - - name: FP16_table 23 | embedding_dim: 512 24 | num_embeddings: 100_000 25 | feature_names: ["additional_0_0"] 26 | data_type: FP16 27 | - name: large_table 28 | embedding_dim: 2048 29 | num_embeddings: 1_000_000 30 | feature_names: ["additional_0_1"] 31 | - [] 32 | - - name: skipped_table 33 | embedding_dim: 128 34 | num_embeddings: 100_000 35 | feature_names: ["additional_2_1"] 36 | PlannerConfig: 37 | additional_constraints: 38 | large_table: 39 | sharding_types: [column_wise] 40 | -------------------------------------------------------------------------------- /torchrec/csrc/dynamic_embedding/details/bits_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include 12 | 13 | namespace torchrec { 14 | 15 | namespace bits_impl { 16 | template 17 | struct Clz { 18 | int operator()(T v) const; 19 | }; 20 | 21 | template 22 | struct Ctz { 23 | int operator()(T v) const; 24 | }; 25 | } // namespace bits_impl 26 | 27 | /** 28 | * Returns the number of leading 0-bits in t, starting at the most significant 29 | * bit position. If t is 0, the result is undefined. 30 | * clz stands for counting leading zeros. 
# Configures LD_LIBRARY_PATH for the target conda environment and installs
# fbgemm-gpu plus the project requirements into it.
#
# Reads: CU_VERSION, CHANNEL, CONDA_ENV, CONDA_RUN, CUDA_HOME, LD_LIBRARY_PATH.

echo "CU_VERSION: ${CU_VERSION}"
echo "CHANNEL: ${CHANNEL}"
echo "CONDA_ENV: ${CONDA_ENV}"

if [[ $CU_VERSION = cu* ]]; then
    # Setting LD_LIBRARY_PATH fixes the runtime error with fbgemm_gpu not
    # being able to locate libnvrtc.so
    echo "[NOVA] Setting LD_LIBRARY_PATH ..."
    # CONDA_ENV is quoted so an environment path containing whitespace is
    # passed to conda as a single argument.
    conda env config vars set -p "${CONDA_ENV}" \
        LD_LIBRARY_PATH="/usr/local/lib:${CUDA_HOME}/lib64:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}"
else
    echo "[NOVA] Setting LD_LIBRARY_PATH ..."
    conda env config vars set -p "${CONDA_ENV}" \
        LD_LIBRARY_PATH="/usr/local/lib:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}"
fi

# Only the nightly and test channels install fbgemm-gpu from a channel-specific
# wheel index; any other channel relies on requirements.txt below.
# CONDA_RUN is intentionally left unquoted: it is presumably a command plus
# arguments that must be word-split -- confirm against the caller.
case "$CHANNEL" in
    nightly | test)
        ${CONDA_RUN} pip install fbgemm-gpu --index-url "https://download.pytorch.org/whl/${CHANNEL}/${CU_VERSION}"
        ;;
esac


${CONDA_RUN} pip install -r requirements.txt
"""Torchrec Datasets

Torchrec contains two popular recsys datasets, the `Kaggle/Criteo Display Advertising `_ Dataset
and the `MovieLens 20M `_ Dataset.

Additionally, it contains a RandomDataset, which is useful to generate random data in the same format as the above.

Lastly, it contains scripts and utilities for pre-processing, loading, etc.

Example::

    from torchrec.datasets.criteo import criteo_terabyte
    datapipe = criteo_terabyte(
        ("/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv")
    )
    datapipe = dp.iter.Batcher(datapipe, 100)
    datapipe = dp.iter.Collator(datapipe)
    batch = next(iter(datapipe))
"""

import torchrec.datasets.criteo  # noqa
import torchrec.datasets.movielens  # noqa
import torchrec.datasets.random  # noqa
import torchrec.datasets.utils  # noqa
15 | 16 | Example: 17 | >>> import torch.quantization as quant 18 | >>> import torchrec.quant as trec_quant 19 | >>> import torchrec as trec 20 | >>> qconfig = quant.QConfig( 21 | >>> activation=quant.PlaceholderObserver, 22 | >>> weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8), 23 | >>> ) 24 | >>> quantized = quant.quantize_dynamic( 25 | >>> module, 26 | >>> qconfig_spec={ 27 | >>> trec.EmbeddingBagCollection: qconfig, 28 | >>> }, 29 | >>> mapping={ 30 | >>> trec.EmbeddingBagCollection: trec_quant.EmbeddingBagCollection, 31 | >>> }, 32 | >>> inplace=inplace, 33 | >>> ) 34 | """ 35 | 36 | from torchrec.quant.embedding_modules import EmbeddingBagCollection # noqa 37 | -------------------------------------------------------------------------------- /torchrec/distributed/benchmark/yaml/sparse_data_dist_base_vbe.yml: -------------------------------------------------------------------------------- 1 | # this is a very basic sparse data dist config 2 | # runs on 2 ranks, showing traces with reasonable workloads 3 | RunOptions: 4 | world_size: 2 5 | batch_size: 16384 6 | num_batches: 10 7 | num_benchmarks: 1 8 | num_profiles: 1 9 | sharding_type: table_wise 10 | profile_dir: "." 
11 | name: "sparse_data_dist_base" 12 | # export_stacks: True # enable this to export stack traces 13 | PipelineConfig: 14 | pipeline: "sparse" 15 | ModelInputConfig: 16 | feature_pooling_avg: 30 17 | use_variable_batch: True 18 | EmbeddingTablesConfig: 19 | num_unweighted_features: 90 20 | num_weighted_features: 80 21 | embedding_feature_dim: 256 22 | additional_tables: 23 | - - name: FP16_table 24 | embedding_dim: 512 25 | num_embeddings: 100_000 26 | feature_names: ["additional_0_0"] 27 | data_type: FP16 28 | - name: large_table 29 | embedding_dim: 2048 30 | num_embeddings: 1_000_000 31 | feature_names: ["additional_0_1"] 32 | - [] 33 | - - name: skipped_table 34 | embedding_dim: 128 35 | num_embeddings: 100_000 36 | feature_names: ["additional_2_1"] 37 | PlannerConfig: 38 | additional_constraints: 39 | large_table: 40 | sharding_types: [column_wise] 41 | -------------------------------------------------------------------------------- /torchrec/modules/tests/test_code_quality.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | 10 | import inspect 11 | import sys 12 | import unittest 13 | 14 | import torch 15 | import torchrec # noqa 16 | from torchrec.linter.module_linter import MAX_NUM_ARGS_IN_MODULE_CTOR 17 | 18 | 19 | class CodeQualityTest(unittest.TestCase): 20 | def test_num_ctor_args(self) -> None: 21 | classes = inspect.getmembers(sys.modules["torchrec"], inspect.isclass) 22 | for class_name, clazz in classes: 23 | if issubclass(clazz, torch.nn.Module): 24 | num_args_excluding_self = ( 25 | len(inspect.getfullargspec(clazz.__init__).args) - 1 26 | ) 27 | self.assertLessEqual( 28 | num_args_excluding_self, 29 | MAX_NUM_ARGS_IN_MODULE_CTOR, 30 | "Modules in TorchRec can have no more than {} constructor args, but {} has {}.".format( 31 | MAX_NUM_ARGS_IN_MODULE_CTOR, class_name, num_args_excluding_self 32 | ), 33 | ) 34 | -------------------------------------------------------------------------------- /torchrec/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | """Torchrec Jagged Tensors 11 | 12 | It has 3 classes: JaggedTensor, KeyedJaggedTensor, KeyedTensor. 13 | 14 | JaggedTensor 15 | 16 | It represents an (optionally weighted) jagged tensor. A JaggedTensor is a 17 | tensor with a jagged dimension, which is a dimension whose slices may be of 18 | different lengths. See KeyedJaggedTensor docstring for full example and further 19 | information. 20 | 21 | KeyedJaggedTensor 22 | 23 | KeyedJaggedTensor has additional "Key" information. Keyed on first dimension, 24 | and jagged on last dimension. Please refer to KeyedJaggedTensor docstring for full example and 25 | further information. 
26 | 27 | KeyedTensor 28 | 29 | KeyedTensor holds a concatenated list of dense tensors each of which can be accessed by a key. 30 | Keyed dimension can be variable length (length_per_key). Common use cases include storage 31 | of pooled embeddings of different dimensions. Please refer to KeyedTensor docstring for full 32 | example and further information. 33 | """ 34 | 35 | from . import jagged_tensor # noqa 36 | -------------------------------------------------------------------------------- /torchrec/csrc/dynamic_embedding/details/io_parameter.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | 11 | namespace torchrec { 12 | 13 | using GlobalIDFetchCallback = void (*)( 14 | void* ctx, 15 | uint32_t offset, 16 | uint32_t optimizer_state, 17 | void* data, 18 | uint32_t data_len); 19 | 20 | struct IOFetchParameter { 21 | const char* table_name; 22 | uint32_t num_cols; 23 | uint32_t num_global_ids; 24 | const int64_t* col_ids; 25 | const int64_t* global_ids; 26 | uint32_t num_optimizer_states; 27 | void* on_complete_context; 28 | GlobalIDFetchCallback on_global_id_fetched; 29 | void (*on_all_fetched)(void* ctx); 30 | }; 31 | 32 | struct IOPushParameter { 33 | const char* table_name; 34 | uint32_t num_cols; 35 | uint32_t num_global_ids; 36 | const int64_t* col_ids; 37 | const int64_t* global_ids; 38 | uint32_t num_optimizer_states; 39 | const uint32_t* optimizer_state_ids; 40 | uint32_t num_offsets; 41 | const uint64_t* offsets; 42 | const void* data; 43 | void* on_complete_context; 44 | void (*on_push_complete)(void* ctx); 45 | }; 46 | 47 | } // namespace torchrec 48 | -------------------------------------------------------------------------------- 
/torchrec/modules/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | """Torchrec Common Modules 11 | 12 | The torchrec modules contain a collection of various modules. 13 | 14 | These modules include: 15 | - extensions of `nn.EmbeddingBag` and `nn.Embedding`, called `EmbeddingBagCollection` 16 | and `EmbeddingCollection` respectively. 17 | - established modules such as `DeepFM `_ and 18 | `CrossNet `_. 19 | - common module patterns such as `MLP` and `SwishLayerNorm`. 20 | - custom modules for TorchRec such as `PositionWeightedModule` and 21 | `LazyModuleExtensionMixin`. 22 | - `EmbeddingTower` and `EmbeddingTowerCollection`, logical "tower" of embeddings 23 | passed to provided interaction module. 24 | """ 25 | 26 | from . import ( # noqa # noqa # noqa # noqa # noqa # noqa # noqa # noqa # noqa 27 | activation, 28 | crossnet, 29 | deepfm, 30 | embedding_configs, 31 | embedding_modules, 32 | embedding_tower, 33 | feature_processor, 34 | lazy_extension, 35 | mlp, 36 | ) 37 | -------------------------------------------------------------------------------- /examples/bert4rec/models/tests/test_bert4rec.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | 10 | #!/usr/bin/env python3 11 | 12 | import unittest 13 | 14 | import torch 15 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor 16 | 17 | from ..bert4rec import BERT4Rec 18 | 19 | 20 | class BERT4RecTest(unittest.TestCase): 21 | def test_bert4rec(self) -> None: 22 | # input tensor 23 | # [2, 4], 24 | # [3, 4, 5], 25 | input_kjt = KeyedJaggedTensor.from_lengths_sync( 26 | keys=["item"], 27 | values=torch.tensor([2, 4, 3, 4, 5]), 28 | lengths=torch.tensor([2, 3]), 29 | ) 30 | device = ( 31 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 32 | ) 33 | input_kjt = input_kjt.to(device) 34 | bert4rec = BERT4Rec( 35 | vocab_size=6, max_len=3, emb_dim=4, nhead=4, num_layers=4, device=device 36 | ) 37 | logits = bert4rec(input_kjt) 38 | assert logits.size() == torch.Size( 39 | [input_kjt.stride(), bert4rec.max_len, bert4rec.vocab_size] 40 | ) 41 | -------------------------------------------------------------------------------- /torchrec/inference/inference_legacy/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-ignore-all-errors[0, 21] 9 | 10 | """Torchrec Inference 11 | 12 | Torchrec inference provides a Torch.Deploy based library for GPU inference. 13 | 14 | These include: 15 | - Model packaging in Python 16 | - `PredictModule` and `PredictFactory` are the contracts between the Python model authoring and the C++ model serving. 17 | - `PredictFactoryPackager` can be used to package a PredictFactory class using torch.package. 18 | - Model serving in C++ 19 | - `BatchingQueue` is a generalized config-based request tensor batching implementation. 
20 | - `GPUExecutor` handles the forward call into the inference model inside Torch.Deploy. 21 | 22 | We implemented an example of how to use this library with the TorchRec DLRM model. 23 | - `examples/dlrm/inference/dlrm_packager.py`: this demonstrates how to export the DLRM model as a torch.package. 24 | - `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model. 25 | """ 26 | 27 | from . import model_packager # noqa 28 | -------------------------------------------------------------------------------- /torchrec/inference/protos/predictor.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | syntax = "proto3"; 10 | 11 | package predictor; 12 | 13 | message SparseFeatures { 14 | int32 num_features = 1; 15 | // int32: T x B 16 | bytes lengths = 2; 17 | // T x B x L (jagged) 18 | bytes values = 3; 19 | bytes weights = 4; 20 | } 21 | 22 | message FloatFeatures { 23 | int32 num_features = 1; 24 | // shape: {B} 25 | bytes values = 2; 26 | } 27 | 28 | message PredictionRequest { 29 | int32 batch_size = 1; 30 | FloatFeatures float_features = 2; 31 | SparseFeatures id_list_features = 3; 32 | SparseFeatures id_score_list_features = 4; 33 | FloatFeatures embedding_features = 5; 34 | SparseFeatures unary_features = 6; 35 | } 36 | 37 | message FloatVec { 38 | repeated float data = 1; 39 | } 40 | 41 | // TODO: See whether FloatVec can be replaced with folly::iobuf 42 | message PredictionResponse { 43 | // Task name to prediction Tensor 44 | map predictions = 1; 45 | } 46 | 47 | // The predictor service definition. Synchronous for now. 
48 | service Predictor { 49 | rpc Predict(PredictionRequest) returns (PredictionResponse) {} 50 | } 51 | -------------------------------------------------------------------------------- /benchmarks/cpp/dynamic_embedding/random_bits_generator_benchmark.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | namespace torchrec { 14 | 15 | void BMRandomBitsGenerator(benchmark::State& state) { 16 | auto n = state.range(0); 17 | auto n_bits_limit = state.range(1); 18 | BitScanner scanner(n); 19 | std::mt19937_64 engine((std::random_device())()); 20 | std::uniform_int_distribution dist(1, n_bits_limit); 21 | uint16_t n_bits = dist(engine); 22 | for (auto _ : state) { 23 | if (n_bits != 0) { 24 | scanner.reset_array([&](std::span span) { 25 | for (auto& v : span) { 26 | v = engine(); 27 | } 28 | }); 29 | } else { 30 | n_bits = dist(engine); 31 | } 32 | benchmark::DoNotOptimize(scanner.is_next_n_bits_all_zero(n_bits)); 33 | } 34 | } 35 | 36 | BENCHMARK(BMRandomBitsGenerator) 37 | ->ArgNames({"n", "limit"}) 38 | ->Args({1024, 32}) 39 | ->Unit(benchmark::kMillisecond) 40 | ->Iterations(1024 * 1024); 41 | 42 | } // namespace torchrec 43 | -------------------------------------------------------------------------------- /torchrec/datasets/scripts/nvt/utils/dask.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import os 9 | import shutil 10 | 11 | import numba 12 | from dask.distributed import Client 13 | from dask_cuda import LocalCUDACluster 14 | from nvtabular.utils import device_mem_size 15 | 16 | 17 | def setup_dask(dask_workdir): 18 | if os.path.exists(dask_workdir): 19 | shutil.rmtree(dask_workdir) 20 | os.makedirs(dask_workdir) 21 | 22 | device_limit_frac = 0.8 # Spill GPU-Worker memory to host at this limit. 23 | device_pool_frac = 0.7 24 | 25 | # Use total device size to calculate device limit and pool_size 26 | device_size = device_mem_size(kind="total") 27 | device_limit = int(device_limit_frac * device_size) 28 | device_pool_size = int(device_pool_frac * device_size) 29 | 30 | cluster = LocalCUDACluster( 31 | protocol="tcp", 32 | n_workers=len(numba.cuda.gpus), 33 | CUDA_VISIBLE_DEVICES=range(len(numba.cuda.gpus)), 34 | device_memory_limit=device_limit, 35 | local_directory=dask_workdir, 36 | rmm_pool_size=(device_pool_size // 256) * 256, 37 | ) 38 | 39 | return Client(cluster) 40 | -------------------------------------------------------------------------------- /torchrec/modules/pruning_logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import logging 9 | from abc import ABC, abstractmethod 10 | from contextlib import contextmanager 11 | from dataclasses import dataclass 12 | from types import SimpleNamespace 13 | from typing import Generator, Optional 14 | 15 | logger: logging.Logger = logging.getLogger(__name__) 16 | 17 | 18 | @dataclass 19 | class PruningLogBase(object): 20 | pass 21 | 22 | 23 | class PruningLogger(ABC): 24 | @classmethod 25 | @abstractmethod 26 | @contextmanager 27 | def pruning_logger( 28 | cls, 29 | event: str, 30 | trainer: Optional[str] = None, 31 | publisher: Optional[str] = None, 32 | ) -> Generator[object, None, None]: 33 | pass 34 | 35 | 36 | class PruningLoggerDefault(PruningLogger): 37 | """ 38 | noop logger as a default 39 | """ 40 | 41 | @classmethod 42 | @contextmanager 43 | def pruning_logger( 44 | cls, 45 | event: str, 46 | trainer: Optional[str] = None, 47 | publisher: Optional[str] = None, 48 | ) -> Generator[object, None, None]: 49 | yield SimpleNamespace() 50 | -------------------------------------------------------------------------------- /docs/source/planner-api-reference.rst: -------------------------------------------------------------------------------- 1 | Planner 2 | ---------------------------------- 3 | 4 | The TorchRec Planner is responsible for determining the most performant, balanced 5 | sharding plan for distributed training and inference. 6 | 7 | The main API for generating a sharding plan is ``EmbeddingShardingPlanner.plan`` 8 | 9 | .. automodule:: torchrec.distributed.types 10 | 11 | .. autoclass:: ShardingPlan 12 | :members: 13 | 14 | .. automodule:: torchrec.distributed.planner.planners 15 | 16 | .. autoclass:: EmbeddingShardingPlanner 17 | :members: 18 | 19 | .. automodule:: torchrec.distributed.planner.enumerators 20 | 21 | .. autoclass:: EmbeddingEnumerator 22 | :members: 23 | 24 | .. automodule:: torchrec.distributed.planner.partitioners 25 | 26 | .. autoclass:: GreedyPerfPartitioner 27 | :members: 28 | 29 | 30 | .. 
automodule:: torchrec.distributed.planner.storage_reservations 31 | 32 | .. autoclass:: HeuristicalStorageReservation 33 | :members: 34 | 35 | .. automodule:: torchrec.distributed.planner.proposers 36 | 37 | .. autoclass:: GreedyProposer 38 | :members: 39 | 40 | 41 | .. automodule:: torchrec.distributed.planner.shard_estimators 42 | 43 | .. autoclass:: EmbeddingPerfEstimator 44 | :members: 45 | 46 | 47 | .. automodule:: torchrec.distributed.planner.shard_estimators 48 | 49 | .. autoclass:: EmbeddingStorageEstimator 50 | :members: 51 | -------------------------------------------------------------------------------- /torchrec/inference/inference_legacy/protos/predictor.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | syntax = "proto3"; 10 | 11 | package predictor; 12 | 13 | message SparseFeatures { 14 | int32 num_features = 1; 15 | // int32: T x B 16 | bytes lengths = 2; 17 | // T x B x L (jagged) 18 | bytes values = 3; 19 | bytes weights = 4; 20 | } 21 | 22 | message FloatFeatures { 23 | int32 num_features = 1; 24 | // shape: {B} 25 | bytes values = 2; 26 | } 27 | 28 | message PredictionRequest { 29 | int32 batch_size = 1; 30 | FloatFeatures float_features = 2; 31 | SparseFeatures id_list_features = 3; 32 | SparseFeatures id_score_list_features = 4; 33 | FloatFeatures embedding_features = 5; 34 | SparseFeatures unary_features = 6; 35 | } 36 | 37 | message FloatVec { 38 | repeated float data = 1; 39 | } 40 | 41 | // TODO: See whether FloatVec can be replaced with folly::iobuf 42 | message PredictionResponse { 43 | // Task name to prediction Tensor 44 | map predictions = 1; 45 | } 46 | 47 | // The predictor service definition. Synchronous for now. 
48 | service Predictor { 49 | rpc Predict(PredictionRequest) returns (PredictionResponse) {} 50 | } 51 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.11.0 FATAL_ERROR) 2 | 3 | project(torchrec_dynamic_embedding 4 | VERSION 0.0.1 5 | LANGUAGES CXX C) 6 | 7 | option(TDE_TORCH_BASE_DIR "torch python directory" "") 8 | 9 | if (NOT TDE_TORCH_BASE_DIR) 10 | message(FATAL_ERROR "TDE_TORCH_BASE_DIR must set." 11 | "Use python -c 'import torch;import os.path;print(os.path.dirname(torch.__file__))'" 12 | " to get this dir.") 13 | else() 14 | find_package(Torch REQUIRED 15 | PATHS "${TDE_TORCH_BASE_DIR}" 16 | NO_DEFAULT_PATH) 17 | endif() 18 | 19 | 20 | if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_CURRENT_SOURCE_DIR}) 21 | set(TDE_IS_TOP_LEVEL_PROJECT ON) 22 | else() 23 | set(TDE_IS_TOP_LEVEL_PROJECT OFF) 24 | endif() 25 | 26 | option(TDE_WITH_TESTING "Enable unittest in C++ side" ${TDE_IS_TOP_LEVEL_PROJECT}) 27 | 28 | if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") 29 | option(TDE_WITH_CXX11_ABI "GLIBCXX use c++11 ABI or not. libtorch installed by conda is not use it by default" OFF) 30 | if (TDE_WITH_CXX11_ABI) 31 | add_definitions("-D_GLIBCXX_USE_CXX11_ABI=1") 32 | else() 33 | add_definitions("-D_GLIBCXX_USE_CXX11_ABI=0") 34 | endif() 35 | endif() 36 | 37 | if (TDE_WITH_TESTING) 38 | enable_testing() 39 | add_subdirectory(tests/memory_io) 40 | endif() 41 | add_subdirectory(src) 42 | -------------------------------------------------------------------------------- /torchrec/csrc/dynamic_embedding/details/io_registry.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | namespace torchrec { 19 | 20 | struct IOProvider { 21 | const char* type; 22 | void* (*initialize)(const char* cfg); 23 | void (*fetch)(void* instance, IOFetchParameter cfg); 24 | void (*push)(void* instance, IOPushParameter cfg); 25 | void (*finalize)(void*); 26 | }; 27 | 28 | class IORegistry { 29 | public: 30 | void register_provider(IOProvider provider); 31 | void register_plugin(const char* filename); 32 | [[nodiscard]] IOProvider resolve(const std::string& name) const; 33 | 34 | static IORegistry& Instance(); 35 | 36 | private: 37 | IORegistry() = default; 38 | ska::flat_hash_map providers_; 39 | struct DLCloser { 40 | void operator()(void* ptr) const; 41 | }; 42 | 43 | using DLPtr = std::unique_ptr; 44 | std::vector dls_; 45 | }; 46 | 47 | } // namespace torchrec 48 | -------------------------------------------------------------------------------- /torchrec/csrc/dynamic_embedding/details/bitmap.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include 12 | 13 | namespace torchrec { 14 | 15 | /** 16 | * Bitmap 17 | * 18 | * A bitmap for recording whether num_bits of slots are 19 | * occupied or free. 20 | */ 21 | template 22 | struct Bitmap { 23 | explicit Bitmap(int64_t num_bits); 24 | Bitmap(const Bitmap&) = delete; 25 | Bitmap(Bitmap&&) noexcept = default; 26 | 27 | /** 28 | * Returns the position of the next free slot. 
29 | * If the bitmap is full, return `num_total_bits_`. 30 | */ 31 | int64_t next_free_bit(); 32 | 33 | /** 34 | * Set the slot of position `offset` to free. 35 | */ 36 | void free_bit(int64_t offset); 37 | 38 | /** 39 | * Returns whether all slots in the bitmap are occupied. 40 | */ 41 | bool full() const; 42 | 43 | static constexpr int64_t num_bits_per_value = sizeof(T) * 8; 44 | 45 | const int64_t num_total_bits_; 46 | const int64_t num_values_; 47 | std::unique_ptr values_; 48 | 49 | int64_t next_free_bit_; 50 | }; 51 | 52 | } // namespace torchrec 53 | 54 | #include 55 | -------------------------------------------------------------------------------- /torchrec/ir/types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | #!/usr/bin/env python3 11 | 12 | import abc 13 | from typing import Any, Dict, List, Optional 14 | 15 | import torch 16 | 17 | from torch import nn 18 | 19 | 20 | class SerializerInterface(abc.ABC): 21 | """ 22 | Interface for Serializer classes for torch.export IR. 
23 | """ 24 | 25 | @classmethod 26 | @property 27 | def module_to_serializer_cls(cls) -> Dict[str, Any]: 28 | raise NotImplementedError 29 | 30 | @classmethod 31 | @abc.abstractmethod 32 | def encapsulate_module(cls, module: nn.Module) -> List[str]: 33 | # Take the eager embedding module and encapsulate the module, including serialization 34 | # and meta_forward-swapping, then returns a list of children (fqns) which needs further encapsulation 35 | raise NotImplementedError 36 | 37 | @classmethod 38 | @abc.abstractmethod 39 | def decapsulate_module( 40 | cls, module: nn.Module, device: Optional[torch.device] = None 41 | ) -> nn.Module: 42 | # Take the eager embedding module and decapsulate it by removing serialization and meta_forward-swapping 43 | raise NotImplementedError 44 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/mixed_lfu_lru_strategy_benchmark.cpp: -------------------------------------------------------------------------------- 1 | #include "benchmark/benchmark.h" 2 | #include "tde/details/mixed_lfu_lru_strategy.h" 3 | 4 | namespace tde::details { 5 | void BM_MixedLFULRUStrategy(benchmark::State& state) { 6 | size_t num_ext_values = state.range(0); 7 | std::vector ext_values(num_ext_values); 8 | 9 | MixedLFULRUStrategy strategy; 10 | for (auto& v : ext_values) { 11 | v = strategy.Update(0, 0, std::nullopt); 12 | } 13 | 14 | size_t num_elems = state.range(1); 15 | std::default_random_engine engine((std::random_device())()); 16 | size_t time = 0; 17 | for (auto _ : state) { 18 | state.PauseTiming(); 19 | std::vector offsets; 20 | offsets.reserve(num_elems); 21 | for (size_t i = 0; i < num_elems; ++i) { 22 | std::uniform_int_distribution dist(0, num_elems - 1); 23 | offsets.emplace_back(dist(engine)); 24 | } 25 | state.ResumeTiming(); 26 | 27 | ++time; 28 | strategy.UpdateTime(time); 29 | for (auto& v : offsets) { 30 | ext_values[v] = strategy.Update(0, 0, ext_values[v]); 31 | } 
32 | } 33 | } 34 | 35 | BENCHMARK(BM_MixedLFULRUStrategy) 36 | ->ArgNames({"num_ext_values", "num_elems_per_iter"}) 37 | ->Args({30000000, 1024 * 1024}) 38 | ->Args({300000000, 1024 * 1024}) 39 | ->Unit(benchmark::kMillisecond) 40 | ->Iterations(100); 41 | 42 | } // namespace tde::details 43 | -------------------------------------------------------------------------------- /torchrec/datasets/scripts/nvt/utils/criteo_constant.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # This file is needed because NVTabular will create so many ghost processes if we 9 | # import from TorchRec criteo datasets 10 | 11 | from typing import List 12 | 13 | 14 | FREQUENCY_THRESHOLD = 3 15 | INT_FEATURE_COUNT = 13 16 | CAT_FEATURE_COUNT = 26 17 | DAYS = 24 18 | DEFAULT_LABEL_NAME = "label" 19 | DEFAULT_INT_NAMES: List[str] = [f"int_{idx}" for idx in range(INT_FEATURE_COUNT)] 20 | DEFAULT_CAT_NAMES: List[str] = [f"cat_{idx}" for idx in range(CAT_FEATURE_COUNT)] 21 | DEFAULT_COLUMN_NAMES: List[str] = [ 22 | DEFAULT_LABEL_NAME, 23 | *DEFAULT_INT_NAMES, 24 | *DEFAULT_CAT_NAMES, 25 | ] 26 | NUM_EMBEDDINGS_PER_FEATURE = [ 27 | 40000000, 28 | 39060, 29 | 17295, 30 | 7424, 31 | 20265, 32 | 3, 33 | 7122, 34 | 1543, 35 | 63, 36 | 40000000, 37 | 3067956, 38 | 405282, 39 | 10, 40 | 2209, 41 | 11938, 42 | 155, 43 | 4, 44 | 976, 45 | 14, 46 | 40000000, 47 | 40000000, 48 | 40000000, 49 | 590152, 50 | 12973, 51 | 108, 52 | 36, 53 | ] 54 | 55 | NUM_EMBEDDINGS_PER_FEATURE_DICT = dict( 56 | zip(DEFAULT_CAT_NAMES, NUM_EMBEDDINGS_PER_FEATURE) 57 | ) 58 | -------------------------------------------------------------------------------- /examples/nvt_dataloader/aws_component.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | import torchx.specs as specs 11 | from torchx.components.dist import ddp 12 | 13 | 14 | def run_dlrm_main(num_trainers: int = 8, *script_args: str) -> specs.AppDef: 15 | """ 16 | Args: 17 | num_trainers: The number of trainers to use. 18 | script_args: A variable number of parameters to provide dlrm_main.py. 19 | """ 20 | cwd = os.getcwd() 21 | entrypoint = os.path.join(cwd, "train_torchrec.py") 22 | 23 | user = os.environ.get("USER") 24 | image = f"/data/home/{user}" 25 | 26 | if num_trainers > 8 and num_trainers % 8 != 0: 27 | raise ValueError( 28 | "Trainer jobs spanning multiple hosts must be in multiples of 8." 29 | ) 30 | nproc_per_node = 8 if num_trainers >= 8 else num_trainers 31 | num_replicas = max(num_trainers // 8, 1) 32 | 33 | return ddp( 34 | *script_args, 35 | name="train_dlrm", 36 | image=image, 37 | # AWS p4d instance (https://aws.amazon.com/ec2/instance-types/p4/). 38 | cpu=96, 39 | gpu=8, 40 | memMB=-1, 41 | script=entrypoint, 42 | j=f"{num_replicas}x{nproc_per_node}", 43 | ) 44 | -------------------------------------------------------------------------------- /torchrec/distributed/model_tracker/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | 10 | """Torchrec Model Tracker 11 | 12 | The model tracker module provides functionality to track and retrieve unique IDs and 13 | embeddings for supported modules during training. This is useful for identifying and 14 | retrieving the latest delta or unique rows for a model, which can help compute topk 15 | or to stream updated embeddings from predictors to trainers during online training. 16 | 17 | Key features include: 18 | - Tracking unique IDs and embeddings for supported modules 19 | - Support for multiple consumers with independent tracking 20 | - Configurable tracking modes (ID_ONLY, EMBEDDING) 21 | - Compaction of tracked data to reduce memory usage 22 | """ 23 | 24 | from torchrec.distributed.model_tracker.delta_store import DeltaStore # noqa 25 | from torchrec.distributed.model_tracker.model_delta_tracker import ( 26 | ModelDeltaTracker, # noqa 27 | SUPPORTED_MODULES, # noqa 28 | ) 29 | from torchrec.distributed.model_tracker.types import ( 30 | IndexedLookup, # noqa 31 | ModelTrackerConfigs, # noqa 32 | Trackers, # noqa 33 | TrackingMode, # noqa 34 | UniqueRows, # noqa 35 | UpdateMode, # noqa 36 | ) 37 | -------------------------------------------------------------------------------- /examples/golden_training/tests/test_train_dlrm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | 10 | import os 11 | import tempfile 12 | import unittest 13 | import uuid 14 | 15 | from torch.distributed.launcher.api import elastic_launch, LaunchConfig 16 | from torchrec.test_utils import skip_if_asan 17 | 18 | from ..train_dlrm import train 19 | 20 | 21 | class TrainTest(unittest.TestCase): 22 | @classmethod 23 | def _run_train(cls) -> None: 24 | train( 25 | embedding_dim=16, 26 | num_iterations=10, 27 | ) 28 | 29 | @skip_if_asan 30 | def test_train_function(self) -> None: 31 | with tempfile.TemporaryDirectory() as tmpdir: 32 | lc = LaunchConfig( 33 | min_nodes=1, 34 | max_nodes=1, 35 | nproc_per_node=2, 36 | run_id=str(uuid.uuid4()), 37 | rdzv_backend="c10d", 38 | rdzv_endpoint=os.path.join(tmpdir, "rdzv"), 39 | rdzv_configs={"store_type": "file"}, 40 | start_method="spawn", 41 | monitor_interval=1, 42 | max_restarts=0, 43 | ) 44 | 45 | elastic_launch(config=lc, entrypoint=self._run_train)() 46 | -------------------------------------------------------------------------------- /torchrec/sparse/tensor_dict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | 10 | from typing import List, Optional 11 | 12 | import torch 13 | from tensordict import TensorDict 14 | 15 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor 16 | 17 | 18 | def maybe_td_to_kjt( 19 | features: KeyedJaggedTensor, keys: Optional[List[str]] = None 20 | ) -> KeyedJaggedTensor: 21 | if torch.jit.is_scripting(): 22 | assert isinstance(features, KeyedJaggedTensor) 23 | return features 24 | if isinstance(features, TensorDict): 25 | if keys is None: 26 | keys = list(features.keys()) 27 | values = torch.cat([features[key]._values for key in keys], dim=0) 28 | lengths = torch.cat( 29 | [ 30 | ( 31 | (features[key]._lengths) 32 | if features[key]._lengths is not None 33 | else torch.diff(features[key]._offsets) 34 | ) 35 | for key in keys 36 | ], 37 | dim=0, 38 | ) 39 | return KeyedJaggedTensor( 40 | keys=keys, 41 | values=values, 42 | lengths=lengths, 43 | ) 44 | else: 45 | return features 46 | -------------------------------------------------------------------------------- /torchrec/inference/inference_legacy/tests/ValidationTest.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "torchrec/inference/Validation.h" 10 | 11 | #include 12 | #include 13 | 14 | TEST(ValidationTest, validateSparseFeatures) { 15 | auto values = at::tensor({1, 2, 3, 4}); 16 | auto lengths = at::tensor({1, 1, 1, 1}); 17 | auto weights = at::tensor({.1, .2, .3, .4}); 18 | 19 | // pass 1D 20 | EXPECT_TRUE(torchrec::validateSparseFeatures(values, lengths, weights)); 21 | 22 | // pass 2D 23 | lengths.reshape({2, 2}); 24 | EXPECT_TRUE(torchrec::validateSparseFeatures(values, lengths, weights)); 25 | 26 | // fail 1D 27 | auto invalidLengths = at::tensor({1, 2, 1, 1}); 28 | EXPECT_FALSE( 29 | torchrec::validateSparseFeatures(values, invalidLengths, weights)); 30 | 31 | // fail 2D 32 | invalidLengths.reshape({2, 2}); 33 | EXPECT_FALSE( 34 | torchrec::validateSparseFeatures(values, invalidLengths, weights)); 35 | } 36 | 37 | TEST(ValidationTest, validateDenseFeatures) { 38 | auto values = at::tensor({1, 2, 3, 4}); 39 | EXPECT_TRUE(torchrec::validateDenseFeatures(values, 1)); 40 | EXPECT_TRUE(torchrec::validateDenseFeatures(values, 4)); 41 | EXPECT_FALSE(torchrec::validateDenseFeatures(values, 3)); 42 | } 43 | -------------------------------------------------------------------------------- /torchrec/distributed/sharding/twcw_sharding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
class TwCwPooledEmbeddingSharding(CwPooledEmbeddingSharding):
    """
    Shards embedding bags table-wise column-wise, i.e., a given embedding table is
    partitioned along its columns and the table slices are placed on all ranks
    within a host group.

    This class adds no behavior of its own: its constructor forwards every
    argument to `CwPooledEmbeddingSharding`.
    """

    def __init__(
        self,
        sharding_infos: List[EmbeddingShardingInfo],
        env: ShardingEnv,
        device: Optional[torch.device] = None,
        permute_embeddings: bool = False,
        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
    ) -> None:
        # All arguments are passed through unchanged to the parent class.
        super().__init__(
            sharding_infos,
            env,
            device,
            permute_embeddings=permute_embeddings,
            qcomm_codecs_registry=qcomm_codecs_registry,
        )
# Same as EmbeddingBagConfig but serializable
# NOTE(review): field names presumably mirror EmbeddingBagConfig one-to-one;
# confirm against torchrec.modules.embedding_configs.
@dataclass
class EmbeddingBagConfigMetadata:
    num_embeddings: int
    embedding_dim: int
    name: str
    data_type: DataType
    feature_names: List[str]
    weight_init_max: Optional[float]
    weight_init_min: Optional[float]
    need_pos: bool
    pooling: PoolingType


# Serializable description of a collection of tables (a list of
# EmbeddingBagConfigMetadata) plus a weighted flag and an optional device name.
# NOTE(review): "EBC" presumably stands for EmbeddingBagCollection; confirm.
@dataclass
class EBCMetadata:
    tables: List[EmbeddingBagConfigMetadata]
    is_weighted: bool
    device: Optional[str]


# NOTE(review): presumably metadata for a feature-processed EBC module; the
# meaning of `is_fp_collection` and `features` is not visible here; confirm.
@dataclass
class FPEBCMetadata:
    is_fp_collection: bool
    features: List[str]


# Holds a single maximum feature length.
# NOTE(review): presumably for a PositionWeightedModule; confirm.
@dataclass
class PositionWeightedModuleMetadata:
    max_feature_length: int


# Holds (name, max length) pairs stored as a list of tuples rather than a dict.
# NOTE(review): the string is presumably a feature name; confirm.
@dataclass
class PositionWeightedModuleCollectionMetadata:
    max_feature_lengths: List[Tuple[str, int]]


# NOTE(review): presumably metadata for a KTRegroupAsDict module; `emb_dtype`
# is kept as an optional string rather than a dtype object. Confirm with the
# serializer that produces it.
@dataclass
class KTRegroupAsDictMetadata:
    groups: List[List[str]]
    keys: List[str]
    emb_dtype: Optional[str]
cmake_minimum_required(VERSION 3.11.0 FATAL_ERROR)

project(TorchRec
  LANGUAGES CXX C)

find_package(Torch REQUIRED)

set(CMAKE_CXX_STANDARD 20)

include(FetchContent)

option(BUILD_TEST "Build C++ test binaries (need gtest and gbenchmark)" OFF)

add_definitions("-D_GLIBCXX_USE_CXX11_ABI=0")

add_subdirectory(torchrec/csrc)

if (BUILD_TEST)
  # FetchContent_MakeAvailable() only exists from CMake 3.14 onwards, while the
  # project minimum above is 3.11. Fail with a clear message instead of an
  # "unknown command" error when tests are requested on an older CMake.
  if (CMAKE_VERSION VERSION_LESS 3.14)
    message(FATAL_ERROR
      "BUILD_TEST requires CMake 3.14 or newer (FetchContent_MakeAvailable); "
      "found ${CMAKE_VERSION}")
  endif()

  FetchContent_Declare(googletest
    GIT_REPOSITORY https://github.com/google/googletest.git
    GIT_TAG v1.12.0
  )
  FetchContent_Declare(google_benchmark
    GIT_REPOSITORY https://github.com/google/benchmark.git
    GIT_TAG v1.5.6
  )

  # We will not need to test benchmark lib itself.
  set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark testing as we don't need it.")
  # We will not need to install benchmark since we link it statically.
  set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "Disable benchmark install to avoid overwriting vendor install.")
  FetchContent_MakeAvailable(googletest google_benchmark)

  enable_testing()
  add_subdirectory(benchmarks/cpp)
  add_subdirectory(test/cpp)
endif()
namespace torchrec {

// Benchmarks NaiveIDTransformer::transform on a 1024x1024 tensor of int64 ids.
// Each iteration refills `global_ids` with random values drawn from the range
// given by the benchmark arguments ("rand_from", "rand_to"); that refill is
// excluded from the measured time.
//
// NOTE(review): this copy appears to have lost its template argument lists
// (`data_ptr()`, `static_cast(...)`, `std::span{...}`) and the targets of the
// `#include` lines; restore them from the upstream file before building.
static void BM_NaiveIDTransformer(benchmark::State& state) {
  NaiveIDTransformer transformer(2e8);
  torch::Tensor global_ids = torch::empty({1024, 1024}, torch::kLong);
  // Output buffer with the same shape and dtype as the input ids.
  torch::Tensor cache_ids = torch::empty_like(global_ids);
  for (auto _ : state) {
    // Generate fresh input outside the timed region.
    state.PauseTiming();
    global_ids.random_(state.range(0), state.range(1));
    state.ResumeTiming();
    // Timed region: transform every element of global_ids into cache_ids.
    transformer.transform(
        std::span{
            global_ids.template data_ptr(),
            static_cast(global_ids.numel())},
        std::span{
            cache_ids.template data_ptr(),
            static_cast(cache_ids.numel())});
  }
}

// Two configurations: ids drawn from [1e10, 2e10) and from [1e6, 2e6),
// 100 iterations each, reported in milliseconds.
BENCHMARK(BM_NaiveIDTransformer)
    ->Iterations(100)
    ->Unit(benchmark::kMillisecond)
    ->ArgNames({"rand_from", "rand_to"})
    ->Args({static_cast(1e10), static_cast(2e10)})
    ->Args({static_cast(1e6), static_cast(2e6)});

} // namespace torchrec
class AwaitableInstance(Awaitable[torch.Tensor]):
    """Minimal ``Awaitable`` whose ``_wait_impl`` yields the tensor ``[1.0, 2.0, 3.0]``."""

    def __init__(self) -> None:
        super().__init__()

    def _wait_impl(self) -> torch.Tensor:
        payload = [1.0, 2.0, 3.0]
        return torch.FloatTensor(payload)


class AwaitableTests(unittest.TestCase):
    """Checks that functions appended to ``callbacks`` are reflected in ``wait()``'s result."""

    def test_callback(self) -> None:
        """A single doubling callback turns [1, 2, 3] into [2, 4, 6]."""
        pending = AwaitableInstance()
        # pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got
        #  `(ret: Any) -> int`.
        pending.callbacks.append(lambda ret: 2 * ret)
        result = pending.wait()
        expected = torch.FloatTensor([2.0, 4.0, 6.0])
        self.assertTrue(torch.allclose(result, expected))

    def test_callback_chained(self) -> None:
        """Doubling followed by squaring turns [1, 2, 3] into [4, 16, 36]."""
        pending = AwaitableInstance()
        # pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got
        #  `(ret: Any) -> int`.
        pending.callbacks.append(lambda ret: 2 * ret)
        pending.callbacks.append(lambda ret: ret**2)
        result = pending.wait()
        expected = torch.FloatTensor([4.0, 16.0, 36.0])
        self.assertTrue(torch.allclose(result, expected))
class TestTableBatchedEmbeddingSlice(unittest.TestCase):
    """Tests for ``TableBatchedEmbeddingSlice`` built on a dense table-batched embedding."""

    def test_is_view(self) -> None:
        """A slice starting at offset 0 reports the same data pointer as the full weights."""
        run_on_cpu = not torch.cuda.is_available()
        table_specs = [(2, 4), (2, 4)]
        emb = DenseTableBatchedEmbeddingBagsCodegen(table_specs, use_cpu=run_on_cpu)
        first_table = TableBatchedEmbeddingSlice(emb.weights, 0, 8, 2, 4)
        slice_ptr = first_table.data_ptr()
        self.assertEqual(slice_ptr, emb.weights.data_ptr())

    def test_copy(self) -> None:
        """A deep copy of a slice reports a different data pointer than the slice."""
        run_on_cpu = not torch.cuda.is_available()
        table_specs = [(2, 4), (2, 4)]
        emb = DenseTableBatchedEmbeddingBagsCodegen(table_specs, use_cpu=run_on_cpu)
        first_table = TableBatchedEmbeddingSlice(emb.weights, 0, 8, 2, 4)
        duplicate = copy.deepcopy(first_table)
        self.assertNotEqual(first_table.data_ptr(), duplicate.data_ptr())
class FusedOptimizer(KeyedOptimizer, abc.ABC):
    """
    Abstract base for optimizers that assume the weight update is done during
    the backward pass, thus step() is a no-op.

    Subclasses must implement both `step` and `zero_grad`.
    """

    @abc.abstractmethod
    # pyre-ignore [2]
    def step(self, closure: Any = None) -> None: ...

    @abc.abstractmethod
    def zero_grad(self, set_to_none: bool = False) -> None: ...

    def __repr__(self) -> str:
        # Delegate to torch.optim.Optimizer's repr implementation explicitly.
        return optim.Optimizer.__repr__(self)


class EmptyFusedOptimizer(FusedOptimizer):
    """
    Fused Optimizer class with no-op step and no parameters to optimize over
    """

    def __init__(self) -> None:
        # Three empty dicts are handed to the KeyedOptimizer constructor.
        # NOTE(review): presumably params, state and param_groups; confirm
        # against KeyedOptimizer.__init__.
        super().__init__({}, {}, {})

    # pyre-ignore
    def step(self, closure: Any = None) -> None:
        # Nothing to update.
        pass

    def zero_grad(self, set_to_none: bool = False) -> None:
        # No gradients to clear.
        pass


class FusedOptimizerModule(abc.ABC):
    """
    Module, which does weight update during backward pass.

    Implementers expose their optimizer through the `fused_optimizer` property.
    """

    @property
    @abc.abstractmethod
    def fused_optimizer(self) -> KeyedOptimizer: ...
7 | 8 | # pyre-strict 9 | 10 | 11 | from torchrec.distributed.train_pipeline.pipeline_context import ( # noqa 12 | In, 13 | Out, 14 | TrainPipelineContext, 15 | ) 16 | from torchrec.distributed.train_pipeline.pipeline_stage import ( # noqa 17 | SparseDataDistUtil, # noqa 18 | StageOut, # noqa 19 | ) 20 | from torchrec.distributed.train_pipeline.tracing import ( # noqa 21 | ArgInfoStepFactory, # noqa 22 | Tracer, # noqa 23 | ) 24 | from torchrec.distributed.train_pipeline.train_pipelines import ( # noqa 25 | EvalPipelineSparseDist, # noqa 26 | PrefetchTrainPipelineSparseDist, # noqa 27 | StagedTrainPipeline, # noqa 28 | TorchCompileConfig, # noqa 29 | TrainPipeline, # noqa 30 | TrainPipelineBase, # noqa 31 | TrainPipelineFusedSparseDist, # noqa 32 | TrainPipelinePT2, # noqa 33 | TrainPipelineSparseDist, # noqa 34 | TrainPipelineSparseDistCompAutograd, # noqa 35 | ) 36 | from torchrec.distributed.train_pipeline.types import ArgInfo, CallArgs # noqa 37 | from torchrec.distributed.train_pipeline.utils import ( # noqa 38 | _override_input_dist_forwards, # noqa 39 | _rewrite_model, # noqa 40 | _start_data_dist, # noqa 41 | _to_device, # noqa 42 | _wait_for_batch, # noqa 43 | DataLoadingThread, # noqa 44 | ) 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 
11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /torchrec/quant/tests/test_tensor_types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
class QuantUtilsTest(unittest.TestCase):
    """Exercises basic tensor operations on ``UInt4Tensor`` and ``UInt2Tensor``."""

    # pyre-ignore
    @unittest.skipIf(
        torch.cuda.device_count() <= 0,
        "Not enough GPUs available",
    )
    def test_uint42_tensor(self) -> None:
        """
        Wrap a 3x8 uint8 tensor as 4-bit and 2-bit tensors and check that
        viewing back as uint8, slicing and in-place copying all work.

        Skipped when no CUDA device is available.
        """
        t_u8 = torch.tensor(
            [
                [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
                [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
                [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            ],
            dtype=torch.uint8,
        )
        t_u4 = UInt4Tensor(t_u8)
        # Return values of detach()/to() are discarded: these calls only check
        # that the operations run on the wrapper types.
        t_u4.detach()

        t_u4.to(torch.device("cuda"))
        # Viewing as uint8 must give back the original bytes.
        assert torch.equal(t_u4.view(torch.uint8), t_u8)
        t_u2 = UInt2Tensor(t_u8)
        t_u2.to(torch.device("cuda"))
        assert torch.equal(t_u2.view(torch.uint8), t_u8)

        # Both halves of the 4-bit tensor are expected to be 8 columns wide.
        for t in [t_u4[:, :8], t_u4[:, 8:]]:
            assert t.size(1) == 8
        t_u4[:, :8].copy_(t_u4[:, 8:])

        # Four-column slices of the 2-bit tensor.
        for t in [t_u2[:, 4:8], t_u2[:, 8:12]]:
            assert t.size(1) == 4

        t_u2[:, 4:8].copy_(t_u2[:, 8:12])
class EagerModelProcessingTests(unittest.TestCase):
    """Tests the eager quantize-then-shard path for an inference model."""

    def test_quantize_shard_cuda(self) -> None:
        """Quantize and shard a ten-table model and expect a single TBE in the sharded EBC."""
        tables = []
        for i in range(10):
            tables.append(
                EmbeddingBagConfig(
                    num_embeddings=10,
                    embedding_dim=4,
                    name="table_" + str(i),
                    feature_names=["feature_" + str(i)],
                )
            )

        sparse_nn = TestSparseNN(tables=tables)
        model = TorchTypesModelInputWrapper(sparse_nn)

        quantized_model = quantize_inference_model(model)
        # Only the sharded model is needed; the second return value is ignored.
        sharded_model, _ = shard_quant_model(quantized_model)

        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `sparse`.
        sharded_qebc = sharded_model._module.sparse.ebc
        tbe_count = len(sharded_qebc.tbes)
        self.assertEqual(tbe_count, 1)
-------------------------------------------------------------------------------- 1 | # Running torchrec using NVTabular DataLoader 2 | 3 | First, run NVTabular preprocessing to convert the Criteo TSV files to Parquet and perform offline preprocessing. 4 | 5 | Please follow the installation instructions in the [README](https://github.com/pytorch/torchrec/tree/main/torchrec/datasets/scripts/nvt) of torchrec/torchrec/datasets/scripts/nvt. 6 | 7 | Afterward, to run the model across 8 GPUs, use the command below 8 | 9 | ``` 10 | torchx run -s local_cwd dist.ddp -j 1x8 --script train_torchrec.py -- --num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 --over_arch_layer_sizes 1024,1024,512,256,1 --dense_arch_layer_sizes 512,256,128 --embedding_dim 128 --binary_path /criteo_binary/split/ --learning_rate 1.0 --validation_freq_within_epoch 1000000 --throughput_check_freq_within_epoch 1000000 --batch_size 256 11 | ``` 12 | 13 | To run with Adagrad as the optimizer, add the flag below 14 | 15 | ``` 16 | --adagrad 17 | ``` 18 | 19 | # Test on A100s 20 | 21 | ## Preliminary Training Results 22 | 23 | **Setup:** 24 | * Dataset: Criteo 1TB Click Logs dataset 25 | * CUDA 11.1, NCCL 2.10.3. 26 | * AWS p4d.24xlarge instances, each with 8 40GB NVIDIA A100s. 27 | 28 | **Results** 29 | 30 | Reproducing MLPerfV1 settings 31 | 1. Embedding per features + model architecture 32 | 2. Learning Rate fixed at 1.0 with SGD 33 | 3. Dataset setup: 34 | - No frequency thresholding 35 | 4. Report > .8025 on validation set (0.8027645945549011 from above script) 36 | 5. Global batch size 2048 37 | -------------------------------------------------------------------------------- /examples/retrieval/data/dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
def get_dataloader(
    batch_size: int, num_embeddings: int, pin_memory: bool = False, num_workers: int = 0
) -> DataLoader:
    """
    Builds a dataloader over randomly generated data for the two tower model.

    The underlying dataset is a ``RandomRecDataset`` keyed by the first two
    default MovieLens ratings column names, with one id per feature and no
    dense features.

    Args:
        batch_size (int): batch size produced by the dataset.
        num_embeddings (int): hash size passed to the random dataset.
        pin_memory (bool): forwarded to the ``DataLoader``.
        num_workers (int): number of dataloader workers.

    Returns:
        DataLoader: PyTorch dataloader wrapping the random dataset.
    """
    feature_keys = DEFAULT_RATINGS_COLUMN_NAMES[:2]
    random_dataset = RandomRecDataset(
        keys=feature_keys,
        batch_size=batch_size,
        hash_size=num_embeddings,
        ids_per_feature=1,
        num_dense=0,
    )

    # The dataset already yields whole batches, so automatic batching in the
    # DataLoader is turned off (batch_size=None, batch_sampler=None).
    return DataLoader(
        random_dataset,
        batch_size=None,
        batch_sampler=None,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )
class MainTest(unittest.TestCase):
    """End-to-end check of the Criteo TSV to ``.npy`` preprocessing script."""

    def test_main(self) -> None:
        """Run ``main`` on a generated TSV and verify the shapes of the three output arrays."""
        num_rows = 10
        name = "day_0"
        with CriteoTest._create_dataset_tsv(
            num_rows=num_rows,
            filename=name,
        ) as in_file_path:
            with tempfile.TemporaryDirectory() as output_dir:
                cli_args = [
                    "--input_dir",
                    os.path.dirname(in_file_path),
                    "--output_dir",
                    output_dir,
                ]
                main(cli_args)

                # Load all three outputs before asserting on any of them.
                dense_path = os.path.join(output_dir, name + "_dense.npy")
                sparse_path = os.path.join(output_dir, name + "_sparse.npy")
                labels_path = os.path.join(output_dir, name + "_labels.npy")
                dense = np.load(dense_path)
                sparse = np.load(sparse_path)
                labels = np.load(labels_path)

                self.assertEqual(dense.shape, (num_rows, INT_FEATURE_COUNT))
                self.assertEqual(sparse.shape, (num_rows, CAT_FEATURE_COUNT))
                self.assertEqual(labels.shape, (num_rows, 1))
7 | */ 8 | 9 | #include "torchrec/inference/Validation.h" 10 | #include "ATen/Functions.h" 11 | 12 | namespace torchrec { 13 | 14 | bool validateSparseFeatures( 15 | at::Tensor& values, 16 | at::Tensor& lengths, 17 | std::optional maybeWeights) { 18 | auto flatLengths = lengths.view(-1); 19 | 20 | // validate sum of lengths equals number of values/weights 21 | auto lengthsTotal = at::sum(flatLengths).item(); 22 | if (lengthsTotal != values.size(0)) { 23 | return false; 24 | } 25 | if (maybeWeights.has_value() && lengthsTotal != maybeWeights->size(0)) { 26 | return false; 27 | } 28 | 29 | // Validate no negative values in lengths. 30 | // Use faster path if contiguous. 31 | if (flatLengths.is_contiguous()) { 32 | int* ptr = (int*)flatLengths.data_ptr(); 33 | for (int i = 0; i < flatLengths.numel(); ++i) { 34 | if (*ptr < 0) { 35 | return false; 36 | } 37 | ptr++; 38 | } 39 | } else { 40 | // accessor does boundary check (slower) 41 | auto acc = flatLengths.accessor(); 42 | for (int i = 0; i < acc.size(0); i++) { 43 | if (acc[i] < 0) { 44 | return false; 45 | } 46 | } 47 | } 48 | 49 | return true; 50 | } 51 | 52 | bool validateDenseFeatures(at::Tensor& values, size_t batchSize) { 53 | return values.size(0) % batchSize == 0; 54 | } 55 | 56 | } // namespace torchrec 57 | -------------------------------------------------------------------------------- /benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
namespace torchrec {
// Benchmarks MixedLFULRUStrategy::update over randomly chosen entries.
// Benchmark arguments: range(0) = "num_ext_values" (number of stored values),
// range(1) = "num_elems_per_iter" (updates performed per timed iteration).
//
// NOTE(review): this copy appears to have lost its template argument lists
// (`std::vector`, `std::uniform_int_distribution`) and the `#include` targets;
// restore them from the upstream file before building.
void BM_MixedLFULRUStrategy(benchmark::State& state) {
  size_t num_ext_values = state.range(0);
  std::vector ext_values(num_ext_values);

  MixedLFULRUStrategy strategy;
  // Seed every slot with an initial value from the strategy.
  for (auto& v : ext_values) {
    v = strategy.update(0, 0, std::nullopt);
  }

  size_t num_elems = state.range(1);
  std::default_random_engine engine((std::random_device())());
  size_t time = 0;
  for (auto _ : state) {
    // Untimed: draw the indices to update in this iteration.
    state.PauseTiming();
    std::vector offsets;
    offsets.reserve(num_elems);
    for (size_t i = 0; i < num_elems; ++i) {
      // NOTE(review): indices are drawn from [0, num_elems - 1] but index
      // ext_values, whose size is num_ext_values. This is in range for the
      // registered arguments; confirm the intended range.
      std::uniform_int_distribution dist(0, num_elems - 1);
      offsets.emplace_back(dist(engine));
    }
    state.ResumeTiming();

    // Timed: advance the strategy clock, then update each selected entry.
    ++time;
    strategy.update_time(time);
    for (auto& v : offsets) {
      ext_values[v] = strategy.update(0, 0, ext_values[v]);
    }
  }
}

// Two table sizes (30M and 300M values), 1024 * 1024 updates per iteration,
// 100 iterations each, reported in milliseconds.
BENCHMARK(BM_MixedLFULRUStrategy)
    ->ArgNames({"num_ext_values", "num_elems_per_iter"})
    ->Args({30000000, 1024 * 1024})
    ->Args({300000000, 1024 * 1024})
    ->Unit(benchmark::kMillisecond)
    ->Iterations(100);

} // namespace torchrec
class SwishLayerNorm(nn.Module):
    """
    Swish activation gated by layer normalization: `Y = X * Sigmoid(LayerNorm(X))`.

    Args:
        input_dims (Union[int, List[int], torch.Size]): dimensions to normalize over.
            For an input of shape [batch_size, d1, d2, d3], passing
            input_dims=[d2, d3] normalizes over the last two dimensions.
        device (Optional[torch.device]): default compute device.

    Example::

        sln = SwishLayerNorm(100)
    """

    def __init__(
        self,
        input_dims: Union[int, List[int], torch.Size],
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        layer_norm = nn.LayerNorm(input_dims, device=device)
        sigmoid = nn.Sigmoid()
        # Kept as a two-stage Sequential so the sub-module layout is unchanged.
        self.norm: torch.nn.modules.Sequential = nn.Sequential(layer_norm, sigmoid)

    def forward(
        self,
        input: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): an input tensor.

        Returns:
            torch.Tensor: the input multiplied element-wise by its gate.
        """
        gate = self.norm(input)
        return input * gate
Meant to test if can run Ray + Pytorch DDP" 19 | rank = int(os.getenv("RANK")) # pyre-ignore[6] 20 | world_size = int(os.getenv("WORLD_SIZE")) # pyre-ignore[6] 21 | master_port = int(os.getenv("MASTER_PORT")) # pyre-ignore[6] 22 | master_addr = os.getenv("MASTER_ADDR") 23 | backend = "gloo" 24 | 25 | print(f"initializing `{backend}` process group") 26 | init_process_group( # pyre-ignore[16] 27 | backend=backend, 28 | init_method=f"tcp://{master_addr}:{master_port}", 29 | rank=rank, 30 | world_size=world_size, 31 | ) 32 | print("successfully initialized process group") 33 | 34 | rank = get_rank() # pyre-ignore[16] 35 | world_size = get_world_size() # pyre-ignore[16] 36 | 37 | t = F.one_hot(torch.tensor(rank), num_classes=world_size) 38 | all_reduce(t) # pyre-ignore[16] 39 | computed_world_size = int(torch.sum(t).item()) 40 | print( 41 | f"rank: {rank}, actual world_size: {world_size}, computed world_size: {computed_world_size}" 42 | ) 43 | return computed_world_size 44 | 45 | 46 | def main() -> None: 47 | compute_world_size() 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /examples/ray/README.md: -------------------------------------------------------------------------------- 1 | # Running torchrec with torchx using Ray scheduler on a Ray cluster 2 | 3 | ``` 4 | pip install --pre torchrec -f https://download.pytorch.org/whl/torchrec/index.html 5 | pip install torchx-nightly 6 | pip install "ray[default]" -qqq 7 | ``` 8 | 9 | Run torchx with the dashboard address and a link to your component 10 | ``` 11 | torchx run -s ray -cfg dashboard_address=localhost:6379,working_dir=~/repos/torchrec/examples/ray,requirements=./requirements.txt dist.ddp -j 1x2 --script ~/repos/torchrec/examples/ray/train_torchrec.py 12 | ``` 13 | 14 | Or run locally 15 | ``` 16 | torchx run -s ray -cfg working_dir=~/repos/torchrec/examples/ray,requirements=./requirements.txt dist.ddp -j 1x2 --script 
~/repos/torchrec/examples/ray/train_torchrec.py 17 | ``` 18 | 19 | To run with GPUs, add the --gpu flag to the dist.ddp component 20 | For other dist.ddp options, see https://pytorch.org/torchx/latest/components/distributed.html 21 | ``` 22 | torchx run -s ray -cfg working_dir=~/repos/torchrec/examples/ray,requirements=./requirements.txt dist.ddp -j 1x2 --gpu 2 --script ~/repos/torchrec/examples/ray/train_torchrec.py 23 | ``` 24 | 25 | To run without the Ray scheduler (torchx only) 26 | For available settings, see https://pytorch.org/torchx/latest/cli.html?highlight=torchx%20run 27 | ``` 28 | torchx run -s local_cwd dist.ddp -j 1x2 --script ~/repos/torchrec/examples/ray/train_torchrec.py 29 | ``` 30 | 31 | A job ID looks like ray://torchx/172.31.16.248:6379-raysubmit_ntquG1dDV6CtFUC5 32 | Replace the job ID in the commands below with your own 33 | 34 | 35 | Get a job status 36 | (one of PENDING, FAILED, INTERRUPTED, etc.) 37 | ``` 38 | torchx status ray://torchx/172.31.16.248:6379-raysubmit_ntquG1dDV6CtFUC5 39 | ``` 40 | 41 | Get logs 42 | ``` 43 | torchx log ray://torchx/172.31.16.248:6379-raysubmit_ntquG1dDV6CtFUC5/worker/0 44 | ``` 45 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to torchrec 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. You can use [lintrunner](https://github.com/pytorch/pytorch/wiki/lintrunner) to do so. 13 | 1. To set up: 14 | ``` 15 | pip install lintrunner 16 | lintrunner init 17 | ``` 18 | 2. To lint your local changes: 19 | ``` 20 | lintrunner 21 | ``` 22 | 3. 
To format locally changed files: 23 | ``` 24 | lintrunner f 25 | ``` 26 | 4. To lint all files: 27 | ``` 28 | lintrunner --all-files 29 | ``` 30 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 31 | 32 | ## Contributor License Agreement ("CLA") 33 | In order to accept your pull request, we need you to submit a CLA. You only need 34 | to do this once to work on any of Facebook's open source projects. 35 | 36 | Complete your CLA here: <https://code.facebook.com/cla> 37 | 38 | ## Issues 39 | We use GitHub issues to track public bugs. Please ensure your description is 40 | clear and has sufficient instructions to be able to reproduce the issue. 41 | 42 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 43 | disclosure of security bugs. In those cases, please go through the process 44 | outlined on that page and do not file a public issue. 45 | 46 | ## License 47 | By contributing to torchrec, you agree that your contributions will be licensed 48 | under the LICENSE file in the root directory of this source tree. 49 | -------------------------------------------------------------------------------- /torchrec/inference/include/torchrec/inference/Exception.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | 12 | namespace torchrec { 13 | 14 | // We have different error code defined for different kinds of exceptions in 15 | // fblearner/sigrid predictor. (Code pointer: 16 | // fblearner/predictor/if/prediction_service.thrift.) We define different 17 | // exception type here so that in fblearner/sigrid predictor we can detect the 18 | // exception type and return the corresponding error code to reflect the right 19 | // info. 
20 | class TorchrecException : public std::runtime_error { 21 | public: 22 | explicit TorchrecException(const std::string& error) 23 | : std::runtime_error(error) {} 24 | }; 25 | 26 | // GPUOverloadException maps to 27 | // PredictionExceptionCode::GPU_BATCHING_QUEUE_TIMEOUT 28 | class GPUOverloadException : public TorchrecException { 29 | public: 30 | explicit GPUOverloadException(const std::string& error) 31 | : TorchrecException(error) {} 32 | }; 33 | 34 | // GPUExecutorOverloadException maps to 35 | // PredictionExceptionCode::GPU_EXECUTOR_QUEUE_TIMEOUT 36 | class GPUExecutorOverloadException : public TorchrecException { 37 | public: 38 | explicit GPUExecutorOverloadException(const std::string& error) 39 | : TorchrecException(error) {} 40 | }; 41 | 42 | // TorchDeployException maps to 43 | // PredictorUserErrorCode::TORCH_DEPLOY_ERROR 44 | class TorchDeployException : public TorchrecException { 45 | public: 46 | explicit TorchDeployException(const std::string& error) 47 | : TorchrecException(error) {} 48 | }; 49 | } // namespace torchrec 50 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/details/id_transformer_variant_impl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace tde::details { 4 | 5 | template 6 | inline bool IDTransformer::Transform( 7 | tcb::span global_ids, 8 | tcb::span cache_ids, 9 | Fetch fetch) { 10 | return strategy_.VisitUpdator([&](auto&& update) -> bool { 11 | return std::visit( 12 | [&](auto&& transformer) -> bool { 13 | return transformer.Transform( 14 | global_ids, 15 | cache_ids, 16 | std::forward(update), 17 | std::move(fetch)); 18 | }, 19 | var_); 20 | }); 21 | } 22 | 23 | template 24 | inline auto IDTransformer::LXUStrategy::VisitUpdator(Visitor visit) 25 | -> std::invoke_result_t { 26 | return std::visit( 27 | [&](auto& s) { 28 | using T = typename std::decay_t::lxu_record_t; 29 | auto update = 
30 | [&](std::optional record, int64_t global_id, int64_t cache_id) { 31 | return s.Update(global_id, cache_id, record); 32 | }; 33 | return visit(update); 34 | }, 35 | strategy_); 36 | } 37 | 38 | template 39 | int64_t IDTransformer::LXUStrategy::Time(T record) { 40 | return std::visit( 41 | [&](auto& s) -> int64_t { return s.Time(record); }, strategy_); 42 | } 43 | 44 | template 45 | inline std::vector IDTransformer::LXUStrategy::Evict( 46 | Iterator iterator, 47 | uint64_t num_to_evict) { 48 | return std::visit( 49 | [&, iterator = std::move(iterator)](auto& s) mutable { 50 | return s.Evict(std::move(iterator), num_to_evict); 51 | }, 52 | strategy_); 53 | } 54 | 55 | } // namespace tde::details 56 | -------------------------------------------------------------------------------- /torchrec/csrc/dynamic_embedding/details/bitmap_impl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include 12 | 13 | namespace torchrec { 14 | 15 | template 16 | inline Bitmap::Bitmap(int64_t num_bits) 17 | : num_total_bits_(num_bits), 18 | num_values_((num_bits + num_bits_per_value - 1) / num_bits_per_value), 19 | values_(new T[num_values_]), 20 | next_free_bit_(0) { 21 | std::fill(values_.get(), values_.get() + num_values_, -1); 22 | } 23 | 24 | template 25 | inline int64_t Bitmap::next_free_bit() { 26 | int64_t result = next_free_bit_; 27 | int64_t offset = result / num_bits_per_value; 28 | T value = values_[offset]; 29 | // set the last 1 bit to zero 30 | values_[offset] = value & (value - 1); 31 | while (values_[offset] == 0 && offset < num_values_) { 32 | offset++; 33 | } 34 | value = values_[offset]; 35 | if (C10_LIKELY(value)) { 36 | next_free_bit_ = offset * num_bits_per_value + ctz(value); 37 | } else { 38 | next_free_bit_ = num_total_bits_; 39 | } 40 | 41 | return result; 42 | } 43 | 44 | template 45 | inline void Bitmap::free_bit(int64_t offset) { 46 | int64_t mask_offset = offset / num_bits_per_value; 47 | int64_t bit_offset = offset % num_bits_per_value; 48 | values_[mask_offset] |= 1 << bit_offset; 49 | next_free_bit_ = std::min(offset, next_free_bit_); 50 | } 51 | template 52 | inline bool Bitmap::full() const { 53 | return next_free_bit_ >= num_total_bits_; 54 | } 55 | 56 | } // namespace torchrec 57 | -------------------------------------------------------------------------------- /torchrec/inference/inference_legacy/tests/predict_module_tests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
import unittest
from typing import Dict

import torch
import torch.nn as nn
from torchrec.inference.modules import PredictModule, quantize_dense


class TestModule(nn.Module):
    """Parameter-only fixture: two linear layers and no ``forward``."""

    def __init__(self) -> None:
        super().__init__()
        self.linear0 = nn.Linear(10, 1)
        self.linear1 = nn.Linear(1, 1)


class TestPredictModule(PredictModule):
    """Minimal ``PredictModule`` subclass used by the tests below."""

    def predict_forward(
        self, batch: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Forward ``batch`` (star-unpacked) to the wrapped module."""
        return self.predict_module(*batch)


class PredictModulesTest(unittest.TestCase):
    def test_predict_module(self) -> None:
        """The wrapper's state dict has the wrapped module's keys and tensors."""
        wrapped = TestModule()
        wrapper = TestPredictModule(wrapped)

        expected_state = wrapped.state_dict()
        actual_state = wrapper.state_dict()

        self.assertEqual(expected_state.keys(), actual_state.keys())

        tensor_pairs = zip(expected_state.values(), actual_state.values())
        for expected_tensor, actual_tensor in tensor_pairs:
            self.assertTrue(torch.equal(expected_tensor, actual_tensor))

    def test_dense_lowering(self) -> None:
        """After ``quantize_dense(..., torch.half)`` every parameter is half."""
        lowered = quantize_dense(TestPredictModule(TestModule()), torch.half)
        for param in lowered.parameters():
            self.assertEqual(param.dtype, torch.half)
# pyre-strict

import abc

import torch


class Multistreamable(abc.ABC):
    """
    Interface for objects that may be handed from one CUDA stream to another.

    torch.Tensor and (Keyed)JaggedTensor implement this interface.
    """

    @abc.abstractmethod
    def record_stream(self, stream: torch.Stream) -> None:
        """
        Mark this object as being used by ``stream``.

        See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
        """
        ...


class Pipelineable(Multistreamable):
    """
    Interface for pipeline inputs: adds device transfer (``to``) on top of the
    stream marking inherited from ``Multistreamable``.

    torch.Tensor implements this interface, so it can be used directly in many
    applications. Another example is torchrec.(Keyed)JaggedTensor, which is the
    input to torchrec.EmbeddingBagCollection, which in turn is often the first
    layer of many models. Models that take compound inputs should have those
    inputs implement this interface.
    """

    @abc.abstractmethod
    def to(self, device: torch.device, non_blocking: bool) -> "Pipelineable":
        """
        Move this object to ``device``.

        According to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html,
        ``to`` may return either self or a copy of self, so always use the
        returned value, for example ``in = in.to(new_device)``.
        """
        ...
7 | */ 8 | 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | namespace torchrec { 16 | TORCH_LIBRARY(tde, m) { 17 | m.def("register_io", [](const std::string& name) { 18 | IORegistry::Instance().register_plugin(name.c_str()); 19 | }); 20 | 21 | m.class_("TransformResult") 22 | .def_readonly("success", &TransformResult::success) 23 | .def_readonly("ids_to_fetch", &TransformResult::ids_to_fetch); 24 | 25 | m.class_("IDTransformer") 26 | .def(torch::init()) 27 | .def("transform", &IDTransformerWrapper::transform) 28 | .def("evict", &IDTransformerWrapper::evict) 29 | .def("save", &IDTransformerWrapper::save); 30 | 31 | m.class_("LocalShardList") 32 | .def(torch::init([]() { return c10::make_intrusive(); })) 33 | .def("append", &LocalShardList::emplace_back); 34 | 35 | m.class_("FetchHandle").def("wait", &FetchHandle::wait); 36 | 37 | m.class_("PS") 38 | .def( 39 | torch::init< 40 | std::string, 41 | c10::intrusive_ptr, 42 | int64_t, 43 | int64_t, 44 | std::string, 45 | int64_t>()) 46 | .def("fetch", &PS::fetch) 47 | .def("evict", &PS::evict); 48 | } 49 | } // namespace torchrec 50 | -------------------------------------------------------------------------------- /torchrec/csrc/dynamic_embedding/id_transformer_wrapper.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace torchrec { 16 | 17 | struct TransformResult : public torch::CustomClassHolder { 18 | TransformResult(bool success, torch::Tensor ids_to_fetch) 19 | : success(success), ids_to_fetch(ids_to_fetch) {} 20 | 21 | // Whether the fetch succeeded (if evicted is not necessary) 22 | bool success; 23 | // new ids to fetch from PS. 
24 | // shape of [num_to_fetch, 2], where each row is consist of 25 | // the global id and cache id of each ID. 26 | torch::Tensor ids_to_fetch; 27 | }; 28 | 29 | class IDTransformerWrapper : public torch::CustomClassHolder { 30 | public: 31 | IDTransformerWrapper( 32 | int64_t num_embedding, 33 | const std::string& id_transformer_type, 34 | const std::string& lxu_strategy_type, 35 | int64_t min_used_freq_power = 5); 36 | 37 | c10::intrusive_ptr transform( 38 | std::vector global_ids, 39 | std::vector cache_ids, 40 | int64_t time); 41 | torch::Tensor evict(int64_t num_to_evict); 42 | torch::Tensor save(); 43 | 44 | private: 45 | std::mutex mu_; 46 | std::unique_ptr transformer_; 47 | std::unique_ptr strategy_; 48 | std::vector ids_to_fetch_; 49 | int64_t time_; 50 | int64_t last_save_time_; 51 | }; 52 | 53 | } // namespace torchrec 54 | -------------------------------------------------------------------------------- /torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | 12 | namespace torchrec { 13 | 14 | // We have different error code defined for different kinds of exceptions in 15 | // fblearner/sigrid predictor. (Code pointer: 16 | // fblearner/predictor/if/prediction_service.thrift.) We define different 17 | // exception type here so that in fblearner/sigrid predictor we can detect the 18 | // exception type and return the corresponding error code to reflect the right 19 | // info. 
20 | class TorchrecException : public std::runtime_error { 21 | public: 22 | explicit TorchrecException(const std::string& error) 23 | : std::runtime_error(error) {} 24 | }; 25 | 26 | // GPUOverloadException maps to 27 | // PredictionExceptionCode::GPU_BATCHING_QUEUE_TIMEOUT 28 | class GPUOverloadException : public TorchrecException { 29 | public: 30 | explicit GPUOverloadException(const std::string& error) 31 | : TorchrecException(error) {} 32 | }; 33 | 34 | // GPUExecutorOverloadException maps to 35 | // PredictionExceptionCode::GPU_EXECUTOR_QUEUE_TIMEOUT 36 | class GPUExecutorOverloadException : public TorchrecException { 37 | public: 38 | explicit GPUExecutorOverloadException(const std::string& error) 39 | : TorchrecException(error) {} 40 | }; 41 | 42 | // TorchDeployException maps to 43 | // PredictorUserErrorCode::TORCH_DEPLOY_ERROR 44 | class TorchDeployException : public TorchrecException { 45 | public: 46 | explicit TorchDeployException(const std::string& error) 47 | : TorchrecException(error) {} 48 | }; 49 | } // namespace torchrec 50 | -------------------------------------------------------------------------------- /torchrec/optim/tests/test_optim.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
# pyre-strict

import unittest

import torch
from torch.distributed.optim import (
    _apply_optimizer_in_backward as apply_optimizer_in_backward,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.optim.optimizers import in_backward_optimizer_filter


class TestInBackwardOptimizerFilter(unittest.TestCase):
    def test_in_backward_optimizer_filter(self) -> None:
        """
        Register an in-backward SGD optimizer on table ``t1`` only, then check
        that the filter returns only ``t1``'s weight with ``include=True`` and
        only ``t2``'s weight with ``include=False``.
        """
        table_configs = [
            EmbeddingBagConfig(
                name=table_name,
                embedding_dim=4,
                num_embeddings=2,
                feature_names=[feature_name],
            )
            for table_name, feature_name in (("t1", "f1"), ("t2", "f2"))
        ]
        ebc = EmbeddingBagCollection(tables=table_configs)

        apply_optimizer_in_backward(
            torch.optim.SGD,
            ebc.embedding_bags["t1"].parameters(),
            optimizer_kwargs={"lr": 1.0},
        )

        included = dict(
            in_backward_optimizer_filter(ebc.named_parameters(), include=True)
        )
        excluded = dict(
            in_backward_optimizer_filter(ebc.named_parameters(), include=False)
        )

        assert set(included.keys()) == {"embedding_bags.t1.weight"}
        assert set(excluded.keys()) == {"embedding_bags.t2.weight"}
7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace torchrec { 18 | 19 | /** 20 | * NaiveIDTransformer 21 | * 22 | * transform GlobalID to CacheID by naive flat hash map 23 | * @tparam LXURecord The extension type used for eviction strategy. 24 | * @tparam Bitmap The bitmap class to record the free cache ids. 25 | */ 26 | template > 27 | class NaiveIDTransformer : public IDTransformer { 28 | public: 29 | explicit NaiveIDTransformer(int64_t num_embedding); 30 | NaiveIDTransformer(const NaiveIDTransformer&) = delete; 31 | NaiveIDTransformer(NaiveIDTransformer&&) noexcept = default; 32 | 33 | bool transform( 34 | std::span global_ids, 35 | std::span cache_ids, 36 | update_t update = transform_default::no_update, 37 | fetch_t fetch = transform_default::no_fetch) override; 38 | 39 | void evict(std::span global_ids) override; 40 | 41 | iterator_t iterator() const override; 42 | 43 | private: 44 | struct CacheValue { 45 | int64_t cache_id; 46 | lxu_record_t lxu_record; 47 | }; 48 | 49 | ska::flat_hash_map global_id2cache_value_; 50 | Bitmap bitmap_; 51 | }; 52 | 53 | } // namespace torchrec 54 | 55 | #include 56 | -------------------------------------------------------------------------------- /torchrec/types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from abc import abstractmethod 11 | from enum import Enum, unique 12 | 13 | import torch 14 | from torch import nn 15 | 16 | 17 | class CacheMixin: 18 | """ 19 | A mixin to allow modules that cache computation to clear the cache. 20 | """ 21 | 22 | @abstractmethod 23 | def clear_cache(self) -> None: ... 
class CopyMixIn:
    """
    Base for mixins that define how a module is copied to a device.
    """

    @abstractmethod
    def copy(self, device: torch.device) -> nn.Module: ...


class ModuleCopyMixin(CopyMixIn):
    """
    A mixin to allow modules to override copy behaviors in DMP.

    ``copy`` delegates to the module's own ``to(device)``.
    """

    def copy(self, device: torch.device) -> nn.Module:
        # pyre-ignore [16]
        return self.to(device)


class ModuleNoCopyMixin(CopyMixIn):
    """
    A mixin to allow modules to override copy behaviors in DMP.

    ``copy`` ignores ``device`` and returns the object itself.
    """

    def copy(self, device: torch.device) -> nn.Module:
        # pyre-ignore [7]
        return self


# moved DataType here to avoid circular import
# TODO: organize types and dependencies
@unique
class DataType(Enum):
    """
    Data types accepted for embedding data.

    Our fusion implementation supports only certain types of data, so it makes
    sense to restrict the non-fused version to the same set.
    """

    FP32 = "FP32"
    FP16 = "FP16"
    BF16 = "BF16"
    INT64 = "INT64"
    INT32 = "INT32"
    INT8 = "INT8"
    UINT8 = "UINT8"
    INT4 = "INT4"
    INT2 = "INT2"
    NFP8 = "NFP8"

    def __str__(self) -> str:
        # Render as the bare value ("FP32") rather than "DataType.FP32".
        return self.value
14 | */ 15 | class BitScanner { 16 | public: 17 | explicit BitScanner(size_t n); 18 | BitScanner(const BitScanner&) = delete; 19 | BitScanner(BitScanner&&) noexcept = default; 20 | 21 | /** 22 | * Reset array data 23 | * @tparam Callback (tcb::span) -> void 24 | * @param callback 25 | */ 26 | template 27 | void ResetArray(Callback callback) { 28 | callback(tcb::span(array.get(), size_)); 29 | array_idx_ = 0; 30 | bit_idx = 0; 31 | } 32 | 33 | bool IsNextNBitsAllZero(uint16_t& n_bits); 34 | 35 | // used by unittest only 36 | uint16_t array_idx_{0}; 37 | uint16_t bit_idx{0}; 38 | 39 | private: 40 | std::unique_ptr array; 41 | uint16_t size_; 42 | 43 | // if bit_idx > 64, incr array_idx 44 | void CouldCarryBitIndexToArrayIndex(); 45 | }; 46 | 47 | class RandomBitsGenerator { 48 | public: 49 | RandomBitsGenerator(); 50 | ~RandomBitsGenerator(); 51 | RandomBitsGenerator(const RandomBitsGenerator&) = delete; 52 | RandomBitsGenerator(RandomBitsGenerator&&) noexcept = default; 53 | 54 | /** 55 | * Is next N random bits are all zero or not. 56 | * i.e., the true prob is approximately 1/(2^n_bits). 57 | * 58 | * @param n_bits 59 | * @return 60 | */ 61 | bool IsNextNBitsAllZero(uint16_t n_bits); 62 | 63 | private: 64 | BitScanner scanner_; 65 | std::mt19937_64 engine_; 66 | void ResetScanner(); 67 | }; 68 | 69 | } // namespace tde::details 70 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
import os
import sys

import torch
from setuptools import find_packages

from skbuild import setup

# Extra -D flags handed to CMake; only populated on Linux.
extra_cmake_args = []

if sys.platform == "linux":
    # Candidate nvcc locations: $CMAKE_CUDA_COMPILER first (when set), then the
    # fixed list of install paths below.
    _nvcc_paths = [
        "/usr/bin/nvcc",
        "/usr/local/bin/nvcc",
        "/usr/local/cuda/bin/nvcc",
        "/usr/cuda/bin/nvcc",
    ]
    if os.getenv("CMAKE_CUDA_COMPILER") is not None:
        _nvcc_paths.insert(0, os.getenv("CMAKE_CUDA_COMPILER"))

    # Take the first candidate that exists; fail the build if none does.
    for _nvcc_path in _nvcc_paths:
        try:
            os.stat(_nvcc_path)
        except FileNotFoundError:
            continue
        extra_cmake_args.append(f"-DCMAKE_CUDA_COMPILER={_nvcc_path}")
        break
    else:
        raise RuntimeError(f"Cannot find nvcc in [{','.join(_nvcc_paths)}]")

    # Use $CUDA_TOOLKIT_ROOT_DIR when set; otherwise derive the toolkit root
    # from the nvcc location (the parent of its directory).
    if os.getenv("CUDA_TOOLKIT_ROOT_DIR") is not None:
        extra_cmake_args.append(
            f"-DCUDA_TOOLKIT_ROOT_DIR={os.getenv('CUDA_TOOLKIT_ROOT_DIR')}"
        )
    else:
        extra_cmake_args.append(
            f'-DCUDA_TOOLKIT_ROOT_DIR={os.path.abspath(os.path.join(os.path.dirname(_nvcc_path), ".."))}'
        )

setup(
    name="torchrec_dynamic_embedding",
    package_dir={"": "src"},
    packages=find_packages("src"),
    cmake_args=[
        "-DCMAKE_BUILD_TYPE=Release",
        f"-DTDE_TORCH_BASE_DIR={os.path.dirname(torch.__file__)}",
        "-DTDE_WITH_TESTING=OFF",
        *extra_cmake_args,
    ],
    cmake_install_dir="src",
    version="0.0.1",
)
7 | 8 | # pyre-strict 9 | 10 | import unittest 11 | 12 | import torch 13 | import torch.fx 14 | from torch.testing import FileCheck # @manual 15 | from torchrec.distributed.types import LazyAwaitable 16 | from torchrec.fx import symbolic_trace 17 | 18 | 19 | class TestTracer(unittest.TestCase): 20 | def test_trace_async_module(self) -> None: 21 | class NeedWait(LazyAwaitable[torch.Tensor]): 22 | def __init__(self, obj: torch.Tensor) -> None: 23 | super().__init__() 24 | self._obj = obj 25 | 26 | def _wait_impl(self) -> torch.Tensor: 27 | return self._obj + 3 28 | 29 | class MyAsyncModule(torch.nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | 33 | def forward(self, input) -> LazyAwaitable[torch.Tensor]: 34 | return NeedWait(input + 2) 35 | 36 | # Test automated LazyAwaitable type `wait()` 37 | class AutoModel(torch.nn.Module): 38 | def __init__(self) -> None: 39 | super().__init__() 40 | self.sparse = MyAsyncModule() 41 | 42 | def forward(self, input: torch.Tensor) -> torch.Tensor: 43 | return torch.add(self.sparse(input), input * 10) 44 | 45 | auto_model = AutoModel() 46 | auto_gm = symbolic_trace(auto_model) 47 | FileCheck().check("+ 2").check("NeedWait").check("* 10").run(auto_gm.code) 48 | 49 | input = torch.randn(3, 4) 50 | ref_out = auto_model(input) 51 | traced_out = auto_gm(input) 52 | self.assertTrue(torch.equal(ref_out, traced_out)) 53 | -------------------------------------------------------------------------------- /torchrec/optim/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | """Torchrec Optimizers 11 | 12 | Torchrec contains a special optimizer called KeyedOptimizer. 
KeyedOptimizer exposes the state_dict with meaningful keys: it enables loading both 13 | torch.Tensor and `ShardedTensor` in place, and it prohibits loading an empty state into an already initialized KeyedOptimizer and vice versa. 14 | 15 | It also contains 16 | - several modules wrapping KeyedOptimizer, called CombinedOptimizer and OptimizerWrapper 17 | - Optimizers used in RecSys: e.g. rowwise adagrad/adam/etc 18 | """ 19 | from torchrec.optim.apply_optimizer_in_backward import ( # noqa 20 | apply_optimizer_in_backward, 21 | ) 22 | 23 | from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer # noqa 24 | from torchrec.optim.fused import FusedOptimizer, FusedOptimizerModule # noqa 25 | from torchrec.optim.keyed import ( # noqa 26 | CombinedOptimizer, 27 | KeyedOptimizer, 28 | KeyedOptimizerWrapper, 29 | OptimizerWrapper, 30 | ) 31 | from torchrec.optim.optimizers import ( # noqa 32 | Adagrad, 33 | Adam, 34 | LAMB, 35 | LarsSGD, 36 | PartialRowWiseAdam, 37 | PartialRowWiseLAMB, 38 | SGD, 39 | ) 40 | from torchrec.optim.rowwise_adagrad import RowWiseAdagrad # noqa 41 | from torchrec.optim.semi_sync import SemisyncOptimizer # noqa 42 | from torchrec.optim.warmup import WarmupOptimizer, WarmupPolicy, WarmupStage # noqa 43 | 44 | from . 
import ( # noqa # noqa # noqa # noqa 45 | apply_optimizer_in_backward, 46 | clipping, 47 | fused, 48 | keyed, 49 | optimizers, 50 | rowwise_adagrad, 51 | warmup, 52 | ) 53 | -------------------------------------------------------------------------------- /contrib/dynamic_embedding/src/tde/bind.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "tde/id_transformer.h" 4 | #include "tde/ps.h" 5 | 6 | namespace tde { 7 | TORCH_LIBRARY(tde, m) { 8 | details::IORegistry::RegisterAllDefaultIOs(); 9 | 10 | m.def("register_io", [](const std::string& name) { 11 | details::IORegistry::Instance().RegisterPlugin(name.c_str()); 12 | }); 13 | 14 | m.class_("TransformResult") 15 | .def_readonly("success", &TransformResult::success_) 16 | .def_readonly("ids_to_fetch", &TransformResult::ids_to_fetch_); 17 | 18 | m.class_("TensorList") 19 | .def(torch::init([]() { return c10::make_intrusive(); })) 20 | .def("append", &TensorList::push_back) 21 | .def("__len__", &TensorList::size) 22 | .def("__getitem__", &TensorList::operator[]); 23 | 24 | m.class_("IDTransformer") 25 | .def(torch::init([](int64_t num_embedding, const std::string& config) { 26 | nlohmann::json json = nlohmann::json::parse(config); 27 | return c10::make_intrusive( 28 | num_embedding, std::move(json)); 29 | })) 30 | .def("transform", &IDTransformer::Transform) 31 | .def("evict", &IDTransformer::Evict) 32 | .def("save", &IDTransformer::Save); 33 | 34 | m.class_("LocalShardList") 35 | .def(torch::init([]() { return c10::make_intrusive(); })) 36 | .def("append", &LocalShardList::emplace_back); 37 | 38 | m.class_("FetchHandle").def("wait", &FetchHandle::Wait); 39 | 40 | m.class_("PS") 41 | .def( 42 | torch::init< 43 | std::string, 44 | c10::intrusive_ptr, 45 | int64_t, 46 | int64_t, 47 | std::string, 48 | int64_t>()) 49 | .def("fetch", &PS::Fetch) 50 | .def("evict", &PS::Evict); 51 | } 52 | } // namespace tde 53 | 
class IDTransformer:
    """
    Thin Python front end for the `tde.IDTransformer` custom class.

    The eviction and transform settings are serialized to one JSON document
    and passed, together with `num_embedding`, to the underlying object.
    """

    def __init__(self, num_embedding, eviction_config=None, transform_config=None):
        self._num_embedding = num_embedding
        # Any falsy config (None, empty dict, ...) falls back to the default.
        lxu_strategy = eviction_config or {"type": "mixed_lru_lfu"}
        id_transformer = transform_config or {"type": "naive"}
        serialized = json.dumps(
            {"lxu_strategy": lxu_strategy, "id_transformer": id_transformer}
        )
        self._transformer = torch.classes.tde.IDTransformer(
            num_embedding, serialized
        )

    def transform(self, global_ids: TensorList, cache_ids: TensorList, time: int):
        """
        Transform `global_ids` and store the results in `cache_ids`.

        Returns the `(success, ids_to_fetch)` pair read from the result
        object produced by the underlying transformer.
        """
        source_tensors = global_ids.tensor_list
        target_tensors = cache_ids.tensor_list
        outcome = self._transformer.transform(source_tensors, target_tensors, time)
        return outcome.success, outcome.ids_to_fetch

    def evict(self, num_to_evict):
        """
        Evict `num_to_evict` ids from the transformer and return whatever the
        underlying `evict` call returns.
        """
        evicted = self._transformer.evict(num_to_evict)
        return evicted

    def save(self):
        """
        Get ids to save, as returned by the underlying `save` call.
        """
        to_save = self._transformer.save()
        return to_save
class NoopPerfModel(PerfModel):
    """
    Rates a plan by the largest per-rank sum of shard perf totals.

    "No-op" means nothing is executed: the rating is computed only from the
    `perf` value already attached to every shard.
    """

    def __init__(self, topology: Topology) -> None:
        self._topology = topology

    def rate(self, plan: List[ShardingOption]) -> float:
        """
        Add each shard's `perf.total` to the slot of the shard's rank and
        return the maximum slot over `topology.world_size` ranks.
        """
        totals_by_rank = [0] * self._topology.world_size
        every_shard = (shard for option in plan for shard in option.shards)
        for shard in every_shard:
            # pyre-ignore [6]: Expected `typing_extensions.SupportsIndex`
            totals_by_rank[shard.rank] += cast(Perf, shard.perf).total

        return max(totals_by_rank)


class NoopStorageModel(PerfModel):
    """
    Rates a plan by the largest per-rank sum of shard HBM usage.

    "No-op" means nothing is executed: the rating is computed only from the
    `storage` value already attached to every shard.
    """

    def __init__(self, topology: Topology) -> None:
        self._topology = topology

    def rate(self, plan: List[ShardingOption]) -> float:
        """
        Add each shard's `storage.hbm` to the slot of the shard's rank and
        return the maximum slot over `topology.world_size` ranks.
        """
        hbm_by_rank = [0] * self._topology.world_size
        every_shard = (shard for option in plan for shard in option.shards)
        for shard in every_shard:
            # pyre-ignore [6]: Expected `typing_extensions.SupportsIndex`
            hbm_by_rank[shard.rank] += cast(Storage, shard.storage).hbm

        return max(hbm_by_rank)