├── examples
├── __init__.py
├── ray
│ ├── __init__.py
│ ├── requirements.txt
│ ├── compute_world_size.py
│ └── README.md
├── zch
│ ├── __init__.py
│ └── docs
│ │ └── mpzch_module_dataflow.png
├── golden_training
│ ├── __init__.py
│ ├── tests
│ │ ├── __init__.py
│ │ └── test_train_dlrm.py
│ └── README.md
├── transfer_learning
│ ├── __init__.py
│ └── README.md
├── dlrm
│ └── README.MD
├── retrieval
│ ├── __init__.py
│ ├── data
│ │ ├── __init__.py
│ │ └── dataloader.py
│ ├── modules
│ │ ├── __init__.py
│ │ └── tests
│ │ │ └── __init__.py
│ └── tests
│ │ ├── __init__.py
│ │ ├── test_two_tower_retrieval.py
│ │ └── test_two_tower_train.py
├── prediction
│ └── __init__.py
├── bert4rec
│ └── models
│ │ └── tests
│ │ └── test_bert4rec.py
└── nvt_dataloader
│ ├── aws_component.py
│ └── README.md
├── version.txt
├── torchrec
├── metrics
│ └── __init__.py
├── schema
│ └── __init__.py
├── models
│ ├── tests
│ │ └── __init__.py
│ ├── experimental
│ │ └── __init__.py
│ └── __init__.py
├── optim
│ ├── tests
│ │ ├── __init__.py
│ │ └── test_optim.py
│ ├── test_utils
│ │ └── __init__.py
│ ├── fused.py
│ └── __init__.py
├── sparse
│ ├── tests
│ │ └── __init__.py
│ ├── __init__.py
│ └── tensor_dict.py
├── datasets
│ ├── scripts
│ │ ├── __init__.py
│ │ ├── nvt
│ │ │ ├── utils
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dask.py
│ │ │ │ └── criteo_constant.py
│ │ │ ├── Dockerfile
│ │ │ └── nvt_preproc.sh
│ │ └── tests
│ │ │ └── test_npy_preproc_criteo.py
│ ├── tests
│ │ └── __init__.py
│ ├── test_utils
│ │ └── __init__.py
│ └── __init__.py
├── distributed
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_fused_optimizer.py
│ │ ├── test_model_parallel_nccl.py
│ │ ├── test_model_parallel_gloo_gpu.py
│ │ ├── test_model_parallel_gloo_gpu_single_rank.py
│ │ ├── test_model_parallel_gloo.py
│ │ └── test_awaitable.py
│ ├── benchmark
│ │ ├── __init__.py
│ │ ├── benchmark_zch
│ │ │ ├── data
│ │ │ │ ├── __init__.py
│ │ │ │ ├── configs
│ │ │ │ │ └── criteo_kaggle.yaml
│ │ │ │ ├── Readme.md
│ │ │ │ └── preprocess
│ │ │ │ │ └── __init__.py
│ │ │ ├── figures
│ │ │ │ ├── eval_metrics_auc.png
│ │ │ │ ├── eval_metrics_mae.png
│ │ │ │ ├── eval_metrics_mse.png
│ │ │ │ └── eval_metrics_ne.png
│ │ │ └── models
│ │ │ │ ├── models
│ │ │ │ └── __init__.py
│ │ │ │ ├── __init__.py
│ │ │ │ └── configs
│ │ │ │ ├── dlrmv3.yaml
│ │ │ │ └── dlrmv2.yaml
│ │ ├── README.md
│ │ └── yaml
│ │ │ ├── sparse_data_dist_ssd.yml
│ │ │ ├── sparse_data_dist_emo.yml
│ │ │ ├── sparse_data_dist_base.yml
│ │ │ ├── base_pipeline_light.yml
│ │ │ └── sparse_data_dist_base_vbe.yml
│ ├── composable
│ │ ├── __init__.py
│ │ └── tests
│ │ │ ├── __init__.py
│ │ │ └── test_table_batched_embedding_slice.py
│ ├── sharding
│ │ ├── __init__.py
│ │ └── twcw_sharding.py
│ ├── test_utils
│ │ └── __init__.py
│ ├── logging_handlers.py
│ ├── model_tracker
│ │ ├── trackers
│ │ │ └── __init__.py
│ │ └── __init__.py
│ ├── torchrec_logging_handlers.py
│ ├── global_settings.py
│ ├── planner
│ │ ├── __init__.py
│ │ └── perf_models.py
│ └── train_pipeline
│ │ └── __init__.py
├── modules
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_activation.py
│ │ └── test_code_quality.py
│ ├── README.rst
│ ├── object_pool.py
│ ├── __init__.py
│ ├── pruning_logger.py
│ └── activation.py
├── inference
│ ├── inference_legacy
│ │ ├── src
│ │ │ ├── Executer2.cpp
│ │ │ └── Validation.cpp
│ │ ├── include
│ │ │ └── torchrec
│ │ │ │ └── inference
│ │ │ │ ├── JaggedTensor.h
│ │ │ │ ├── Validation.h
│ │ │ │ ├── TestUtils.h
│ │ │ │ ├── ShardedTensor.h
│ │ │ │ ├── ExceptionHandler.h
│ │ │ │ └── Exception.h
│ │ ├── __init__.py
│ │ ├── protos
│ │ │ └── predictor.proto
│ │ └── tests
│ │ │ ├── ValidationTest.cpp
│ │ │ ├── test_modules.py
│ │ │ └── predict_module_tests.py
│ ├── __init__.py
│ ├── include
│ │ └── torchrec
│ │ │ └── inference
│ │ │ ├── JaggedTensor.h
│ │ │ ├── Validation.h
│ │ │ ├── TestUtils.h
│ │ │ ├── ShardedTensor.h
│ │ │ ├── ExceptionHandler.h
│ │ │ ├── Exception.h
│ │ │ └── ResultSplit.h
│ └── protos
│ │ └── predictor.proto
├── csrc
│ ├── CMakeLists.txt
│ └── dynamic_embedding
│ │ ├── details
│ │ ├── notification.cpp
│ │ ├── types.h
│ │ ├── notification.h
│ │ ├── redis
│ │ │ └── CMakeLists.txt
│ │ ├── lxu_strategy.h
│ │ ├── bits_op.h
│ │ ├── io_parameter.h
│ │ ├── io_registry.h
│ │ ├── bitmap.h
│ │ ├── bitmap_impl.h
│ │ └── naive_id_transformer.h
│ │ ├── CMakeLists.txt
│ │ ├── bind.cpp
│ │ └── id_transformer_wrapper.h
├── pt2
│ └── __init__.py
├── utils
│ └── __init__.py
├── fx
│ ├── __init__.py
│ └── tests
│ │ └── test_tracer.py
├── __init__.py
├── quant
│ ├── __init__.py
│ └── tests
│ │ └── test_tensor_types.py
├── ir
│ ├── types.py
│ └── schema.py
├── streamable.py
└── types.py
├── contrib
└── dynamic_embedding
│ ├── src
│ ├── __init__.py
│ ├── CMakeLists.txt
│ ├── tde
│ │ ├── details
│ │ │ ├── move_only_function.cpp
│ │ │ ├── naive_id_transformer.cpp
│ │ │ ├── cacheline_id_transformer.cpp
│ │ │ ├── redis_io.h
│ │ │ ├── notification_test.cpp
│ │ │ ├── move_only_function_test.cpp
│ │ │ ├── notification.cpp
│ │ │ ├── bits_op_test.cpp
│ │ │ ├── notification.h
│ │ │ ├── id_transformer_variant_test.cpp
│ │ │ ├── bits_op.h
│ │ │ ├── redis_io.cpp
│ │ │ ├── random_bits_generator_benchmark.cpp
│ │ │ ├── mixed_lfu_lru_strategy.cpp
│ │ │ ├── cacheline_id_transformer_benchmark.cpp
│ │ │ ├── naive_id_transformer_benchmark.cpp
│ │ │ ├── mixed_lfu_lru_strategy_benchmark.cpp
│ │ │ ├── url_test.cpp
│ │ │ ├── move_only_function.h
│ │ │ ├── id_transformer_variant_impl.h
│ │ │ └── random_bits_generator.h
│ │ ├── notification.h
│ │ ├── tensor_list.h
│ │ ├── id_transformer.h
│ │ └── bind.cpp
│ └── torchrec_dynamic_embedding
│ │ ├── __init__.py
│ │ ├── distributed
│ │ └── __init__.py
│ │ ├── utils.py
│ │ ├── tensor_list.py
│ │ └── id_transformer.py
│ ├── tests
│ ├── __init__.py
│ ├── memory_io
│ │ └── CMakeLists.txt
│ └── utils.py
│ ├── .gitignore
│ ├── tools
│ ├── repair_wheel.sh
│ ├── build_wheels.sh
│ └── before_linux_build.sh
│ ├── CMakeLists.txt
│ └── setup.py
├── docs
├── .gitignore
├── source
│ ├── _static
│ │ └── img
│ │ │ ├── sharding.png
│ │ │ ├── model_parallel.png
│ │ │ ├── torchrec_forward.png
│ │ │ ├── full_training_loop.png
│ │ │ ├── fused_embedding_tables.png
│ │ │ ├── fused_backward_optimizer.png
│ │ │ └── card-background.svg
│ ├── _templates
│ │ └── layout.html
│ ├── model-parallel-api-reference.rst
│ ├── inference-api-reference.rst
│ ├── datatypes-api-reference.rst
│ ├── modules-api-reference.rst
│ └── planner-api-reference.rst
├── requirements.txt
├── README.md
├── Makefile
└── make.bat
├── .github
├── scripts
│ ├── tests_to_skip.txt
│ └── install_libs.sh
└── workflows
│ ├── pre-commit.yaml
│ ├── pyre.yml
│ ├── validate-nightly-binaries.yml
│ └── validate-binaries.yml
├── pyproject.toml
├── install-requirements.txt
├── benchmarks
├── EBC_benchmarks_dlrm_emb.png
└── cpp
│ ├── CMakeLists.txt
│ └── dynamic_embedding
│ ├── CMakeLists.txt
│ ├── random_bits_generator_benchmark.cpp
│ ├── naive_id_transformer_benchmark.cpp
│ └── mixed_lfu_lru_strategy_benchmark.cpp
├── rfc
├── RFC-0001-assets
│ ├── cache-consistency.png
│ └── logical-cache-consistency.png
└── RFC-0002-assets
│ ├── kv_tbe_prefetch_workflow.png
│ ├── kv_tbe_training_high_level.png
│ └── kv_tbe_pipeline_prefetching.png
├── test
└── cpp
│ ├── CMakeLists.txt
│ └── dynamic_embedding
│ ├── notification_test.cpp
│ ├── redis
│ ├── CMakeLists.txt
│ └── url_test.cpp
│ ├── bits_op_test.cpp
│ └── CMakeLists.txt
├── requirements.txt
├── .pyre_configuration
├── .pre-commit-config.yaml
├── .lintrunner.toml
├── tools
└── lint
│ └── utils.py
├── CMakeLists.txt
├── LICENSE
└── CONTRIBUTING.md
/examples/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/version.txt:
--------------------------------------------------------------------------------
1 | 1.5.0a0
2 |
--------------------------------------------------------------------------------
/examples/ray/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/zch/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/schema/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/models/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/optim/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/sparse/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/golden_training/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/transfer_learning/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/datasets/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/datasets/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/distributed/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/modules/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/golden_training/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/datasets/test_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/distributed/composable/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/distributed/sharding/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/distributed/test_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/models/experimental/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | src/pytorch-sphinx-theme/
2 |
--------------------------------------------------------------------------------
/torchrec/datasets/scripts/nvt/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/distributed/composable/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/distributed/tests/test_fused_optimizer.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/inference/inference_legacy/src/Executer2.cpp:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/benchmark_zch/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.github/scripts/tests_to_skip.txt:
--------------------------------------------------------------------------------
1 | _disabled_in_oss_compatibility
2 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.usort]
2 |
3 | first_party_detection = false
4 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory(3rd)
2 | add_subdirectory(tde)
3 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/move_only_function.cpp:
--------------------------------------------------------------------------------
1 | #include "move_only_function.h"
2 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/naive_id_transformer.cpp:
--------------------------------------------------------------------------------
1 | #include "naive_id_transformer.h"
2 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataloader import save, wrap
2 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | dist
3 | *egg-info
4 | cmake-build*
5 | _skbuild
6 | wheelhouse
7 |
--------------------------------------------------------------------------------
/install-requirements.txt:
--------------------------------------------------------------------------------
1 | fbgemm-gpu
2 | tensordict
3 | torchmetrics==1.0.3
4 | tqdm
5 | pyre-extensions
6 | iopath
7 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.cpp:
--------------------------------------------------------------------------------
1 | #include "tde/details/cacheline_id_transformer.h"
2 |
--------------------------------------------------------------------------------
/benchmarks/EBC_benchmarks_dlrm_emb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/benchmarks/EBC_benchmarks_dlrm_emb.png
--------------------------------------------------------------------------------
/docs/source/_static/img/sharding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/docs/source/_static/img/sharding.png
--------------------------------------------------------------------------------
/docs/source/_static/img/model_parallel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/docs/source/_static/img/model_parallel.png
--------------------------------------------------------------------------------
/examples/zch/docs/mpzch_module_dataflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/examples/zch/docs/mpzch_module_dataflow.png
--------------------------------------------------------------------------------
/rfc/RFC-0001-assets/cache-consistency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/rfc/RFC-0001-assets/cache-consistency.png
--------------------------------------------------------------------------------
/docs/source/_static/img/torchrec_forward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/docs/source/_static/img/torchrec_forward.png
--------------------------------------------------------------------------------
/docs/source/_static/img/full_training_loop.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/docs/source/_static/img/full_training_loop.png
--------------------------------------------------------------------------------
/examples/golden_training/README.md:
--------------------------------------------------------------------------------
1 | # Golden Training Example
2 |
3 | ## Running
4 | `torchx run -s local_cwd dist.ddp -j 1x2 --script train_dlrm.py`
5 |
--------------------------------------------------------------------------------
/rfc/RFC-0002-assets/kv_tbe_prefetch_workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/rfc/RFC-0002-assets/kv_tbe_prefetch_workflow.png
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/redis_io.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | namespace tde::details {
4 |
5 | extern void RegisterRedisIO();
6 |
7 | }
8 |
--------------------------------------------------------------------------------
/docs/source/_static/img/fused_embedding_tables.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/docs/source/_static/img/fused_embedding_tables.png
--------------------------------------------------------------------------------
/rfc/RFC-0001-assets/logical-cache-consistency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/rfc/RFC-0001-assets/logical-cache-consistency.png
--------------------------------------------------------------------------------
/rfc/RFC-0002-assets/kv_tbe_training_high_level.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/rfc/RFC-0002-assets/kv_tbe_training_high_level.png
--------------------------------------------------------------------------------
/docs/source/_static/img/fused_backward_optimizer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/docs/source/_static/img/fused_backward_optimizer.png
--------------------------------------------------------------------------------
/rfc/RFC-0002-assets/kv_tbe_pipeline_prefetching.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/rfc/RFC-0002-assets/kv_tbe_pipeline_prefetching.png
--------------------------------------------------------------------------------
/examples/dlrm/README.MD:
--------------------------------------------------------------------------------
1 | # TorchRec DLRM Example
2 |
3 | See the [facebookresearch/dlrm TorchRec implementation](https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm/).
4 |
--------------------------------------------------------------------------------
/examples/ray/requirements.txt:
--------------------------------------------------------------------------------
1 | --pre torchrec -f https://download.pytorch.org/whl/torchrec/index.html
2 | torchx-nightly
3 | torch --extra-index-url https://download.pytorch.org/whl/cu113
4 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_auc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_auc.png
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_mae.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_mae.png
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_mse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_mse.png
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_ne.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/torchrec/HEAD/torchrec/distributed/benchmark/benchmark_zch/figures/eval_metrics_ne.png
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==5.0.0
2 | pyre-extensions
3 | sphinx-design
4 | sphinx_copybutton
5 | # torch
6 | # PyTorch Theme
7 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
8 |
--------------------------------------------------------------------------------
/docs/source/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 |
3 | {% block footer %}
4 | {{ super() }}
5 |
6 |
9 |
10 | {% endblock %}
11 |
--------------------------------------------------------------------------------
/examples/retrieval/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/tests/memory_io/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # A dummy ParamServer IO plugin. It just stores params in memory.
2 | add_library(memory_io SHARED memory_io.cpp)
3 | set_target_properties(memory_io PROPERTIES PREFIX "")
4 | target_compile_features(memory_io PUBLIC cxx_std_17)
5 |
--------------------------------------------------------------------------------
/examples/retrieval/data/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
--------------------------------------------------------------------------------
/examples/retrieval/modules/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
--------------------------------------------------------------------------------
/examples/retrieval/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
--------------------------------------------------------------------------------
/test/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | add_subdirectory(dynamic_embedding)
8 |
--------------------------------------------------------------------------------
/benchmarks/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | add_subdirectory(dynamic_embedding)
8 |
--------------------------------------------------------------------------------
/examples/retrieval/modules/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
--------------------------------------------------------------------------------
/torchrec/csrc/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | add_subdirectory(dynamic_embedding)
8 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/benchmark_zch/data/configs/criteo_kaggle.yaml:
--------------------------------------------------------------------------------
1 | dataset_path: "/home/lizhouyu/oss_github/dlrm/torchrec_dlrm/criteo_1tb/criteo_kaggle_processed"
2 | batch_size: 4096
3 | seed: 0
4 | multitask_configs:
5 | - task_name: is_click
6 | task_weight: 1
7 | task_type: classification
8 |
--------------------------------------------------------------------------------
/torchrec/datasets/scripts/nvt/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/merlin/merlin-pytorch-training:nightly
2 |
3 | RUN conda install -y pytorch cudatoolkit=11.3 -c pytorch-nightly \
4 | && pip install --pre torchrec_nightly -f https://download.pytorch.org/whl/nightly/torchrec_nightly/index.html
5 |
6 | WORKDIR /app
7 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | black
2 | click
3 | cmake
4 | fbgemm-gpu>=1.4.0
5 | hypothesis==6.70.1
6 | importlib-metadata
7 | iopath
8 | numpy
9 | pandas
10 | pyre-extensions
11 | scikit-build
12 | tensordict
13 | torchmetrics==1.0.3
14 | torchx
15 | tqdm
16 | usort
17 | parameterized
18 | PyYAML
19 | psutil
20 | expecttest>=0.3.0
21 |
--------------------------------------------------------------------------------
/docs/source/model-parallel-api-reference.rst:
--------------------------------------------------------------------------------
1 | Model Parallel
2 | ----------------------------------
3 |
4 | ``DistributedModelParallel`` is the main API for distributed training with TorchRec optimizations.
5 |
6 |
7 | .. automodule:: torchrec.distributed.model_parallel
8 |
9 | .. autoclass:: DistributedModelParallel
10 | :members:
11 |
--------------------------------------------------------------------------------
/torchrec/pt2/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | # __init__ for Python module packaging
11 |
--------------------------------------------------------------------------------
/torchrec/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | from . import experimental # noqa
11 |
--------------------------------------------------------------------------------
/torchrec/inference/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | from . import model_packager, modules # noqa # noqa
11 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/notification_test.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include "gtest/gtest.h"
3 | #include "tde/details/notification.h"
4 | namespace tde::details {
5 | TEST(TDE, notification) {
6 | Notification notification;
7 | std::thread th([&] { notification.Done(); });
8 | notification.Wait();
9 | th.join();
10 | }
11 | } // namespace tde::details
12 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/move_only_function_test.cpp:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "tde/details/move_only_function.h"
3 | namespace tde::details {
4 |
5 | TEST(tde, move_only_function) {
6 | MoveOnlyFunction foo = +[] { return 0; };
7 | ASSERT_EQ(foo(), 0);
8 | ASSERT_TRUE(foo);
9 | foo = {};
10 | ASSERT_FALSE(foo);
11 | }
12 |
13 | } // namespace tde::details
14 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/tools/repair_wheel.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Move a built wheel into DEST_DIR, inserting a "+cu<ver>" local version tag
# right after the torchrec_dynamic_embedding-X.Y.Z version in the file name.
# <ver> is $CUDA_VERSION with the dots removed (e.g. 11.8 -> cu118).
#
# Usage: repair_wheel.sh WHEEL_FILE DEST_DIR   (CUDA_VERSION must be set)
set -xe
WHEEL_FILE=${1:?usage: repair_wheel.sh WHEEL_FILE DEST_DIR}
DEST_DIR=${2:?usage: repair_wheel.sh WHEEL_FILE DEST_DIR}
# Fail fast: an unset CUDA_VERSION would silently produce a bare "+cu" tag.
: "${CUDA_VERSION:?CUDA_VERSION must be set}"

CUDA_SUFFIX=cu$(echo "$CUDA_VERSION" | sed 's#\.##g')
WHEEL_FILENAME=$(basename "${WHEEL_FILE}")
DEST_FILENAME=$(echo "${WHEEL_FILENAME}" | sed -r 's#(torchrec_dynamic_embedding-[0-9]+\.[0-9]+\.[0-9]+)#\1'"+${CUDA_SUFFIX}#g")
mv "${WHEEL_FILE}" "${DEST_DIR}/${DEST_FILENAME}"
10 |
--------------------------------------------------------------------------------
/.pyre_configuration:
--------------------------------------------------------------------------------
1 | {
2 | "exclude": [
3 | ".*/pyre-check/stubs/.*",
4 | ".*/torchrec/datasets*",
5 | ".*/torchrec/models*",
6 | ".*/torchrec/inference/client.py"
7 | ],
8 | "site_package_search_strategy": "all",
9 | "source_directories": [
10 | {
11 | "import_root": ".",
12 | "source": "torchrec"
13 | }
14 | ],
15 | "strict": true,
16 | "version": "0.0.101729681899"
17 | }
18 |
--------------------------------------------------------------------------------
/examples/prediction/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 |
def main() -> None:
    """Entry point for the prediction example.

    Placeholder: it currently performs no work and returns ``None``.
    """
    return None


if __name__ == "__main__":
    main()
17 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/benchmark_zch/models/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 | from .dlrmv2 import make_model_dlrmv2
10 | from .dlrmv3 import make_model_dlrmv3
11 |
--------------------------------------------------------------------------------
/torchrec/modules/README.rst:
--------------------------------------------------------------------------------
1 | .. fbmeta::
2 | hide_page_title=false
3 |
4 | TorchRec Module API Documentation
5 | ==================================
6 |
This page contains information about the internal TorchRec common modules.
A table of contents (toctree) is included below, but it currently has no entries, so it does not render.
9 |
TorchRec Common Modules
-----------------------

.. toctree::
   :glob:
15 |
16 |
17 |
18 | .. automodule:: torchrec.modules
19 | :members:
20 |
--------------------------------------------------------------------------------
/docs/source/_static/img/card-background.svg:
--------------------------------------------------------------------------------
1 |
2 |
14 |
--------------------------------------------------------------------------------
/torchrec/distributed/logging_handlers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import logging
9 |
10 |
__all__: list[str] = []

# Registry of logging handlers keyed by name. Only "default" is registered
# here, bound to a NullHandler (which discards records); other modules may
# add their own entries.
_log_handlers: dict[str, logging.Handler] = dict(
    default=logging.NullHandler(),
)
16 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.5.0
4 | hooks:
5 | - id: check-toml
6 | - id: check-yaml
7 | exclude: packaging/.*
8 | - id: end-of-file-fixer
9 |
10 | - repo: https://github.com/omnilib/ufmt
11 | rev: v2.5.1
12 | hooks:
13 | - id: ufmt
14 | additional_dependencies:
15 | - black == 24.2.0
16 | - usort == 1.0.8.post1
17 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/notification.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include "tde/details/notification.h"
4 |
5 | namespace tde {
6 |
7 | class Notification : public torch::CustomClassHolder {
8 | public:
9 | void Done() {
10 | return notification_.Done();
11 | }
12 | void Wait() {
13 | return notification_.Wait();
14 | }
15 |
16 | private:
17 | details::Notification notification_;
18 | };
19 |
20 | } // namespace tde
21 |
--------------------------------------------------------------------------------
/torchrec/distributed/model_tracker/trackers/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | """MPZCH Raw ID Tracker
11 | """
12 |
13 | from torchrec.distributed.model_tracker.trackers.raw_id_tracker import ( # noqa
14 | RawIdTracker,
15 | )
16 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/benchmark_zch/data/Readme.md:
--------------------------------------------------------------------------------
1 | # Datasets for zero collision hash benchmark
2 |
3 | ## Folder structure
- `configs/`: Configs for each dataset, named as `{dataset_name}.yaml`
- `preprocess/`: Includes scripts that preprocess each dataset so that the returned data has the following format:
6 | - batch
7 | - dense_features
8 | - sparse_features
9 | - labels
10 | - `get_dataloader.py`: the entry point to get the dataloader for each dataset
11 |
--------------------------------------------------------------------------------
/torchrec/distributed/tests/test_model_parallel_nccl.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | from torchrec.distributed.test_utils.test_model_parallel import ModelParallelBase
11 |
12 |
class ModelParallelTestNccl(ModelParallelBase):
    """Runs the ``ModelParallelBase`` tests unchanged.

    No ``setUp`` override is defined here, so the base class's default backend
    is used (presumably NCCL, as the class name suggests — confirm in
    ``ModelParallelBase``).
    """

    pass
15 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/notification.cpp:
--------------------------------------------------------------------------------
1 | #include "notification.h"
2 |
3 | namespace tde::details {
4 | void Notification::Done() {
5 | {
6 | std::lock_guard guard(mtx_);
7 | set_ = true;
8 | }
9 | cv_.notify_all();
10 | }
11 | void Notification::Wait() {
12 | std::unique_lock lock(mtx_);
13 | cv_.wait(lock, [this] { return set_; });
14 | }
15 |
16 | void Notification::Clear() {
17 | set_ = false;
18 | }
19 | } // namespace tde::details
20 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/benchmark_zch/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 | from .apply_optimizers import (
10 | apply_dense_optimizers,
11 | apply_sparse_optimizers,
12 | combine_optimizers,
13 | )
14 | from .make_model import make_model
15 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/bits_op_test.cpp:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "tde/details/bits_op.h"
3 |
namespace tde::details {
// Clz: the expected value 1 for 0x7F.., in each of int8/int32/int64, matches a
// count of leading zero bits within the argument's own width.
TEST(TDE, bits_op_Clz) {
  ASSERT_EQ(Clz(int32_t(0x7FFFFFFF)), 1);
  ASSERT_EQ(Clz(int64_t(0x7FFFFFFFFFFFFFFF)), 1);
  ASSERT_EQ(Clz(int8_t(0x7F)), 1);
}

// Ctz: the expected values (0x2 -> 1, 0xF00 -> 8, 0x4 -> 2) match a count of
// trailing zero bits.
TEST(TDE, bits_op_Ctz) {
  ASSERT_EQ(Ctz(int32_t(0x2)), 1);
  ASSERT_EQ(Ctz(int64_t(0xF00)), 8);
  ASSERT_EQ(Ctz(int8_t(0x4)), 2);
}
} // namespace tde::details
17 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | #!/usr/bin/env python3
9 |
10 | from .comm import (
11 | broadcast_ids_to_evict,
12 | broadcast_transform_result,
13 | gather_global_ids,
14 | scatter_cache_ids,
15 | )
16 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | Docs
2 | ==========
3 |
4 |
5 | ## Building the docs
6 |
7 | To build and preview the docs run the following commands:
8 |
9 | ```bash
10 | cd docs
11 | pip3 install -r requirements.txt
12 | make html
13 | python3 -m http.server 8082 --bind ::
14 | ```
15 |
Now you should be able to view the docs in your browser at the link provided in your terminal; the generated HTML pages are under `build/html/`.
17 |
18 | To reload the preview after making changes, rerun:
19 |
20 | ```bash
21 | make html
22 | python3 -m http.server 8082 --bind ::
23 | ```
24 |
--------------------------------------------------------------------------------
/torchrec/distributed/torchrec_logging_handlers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import logging
10 |
11 | from torchrec.distributed.logging_handlers import _log_handlers
12 |
TORCHREC_LOGGER_NAME = "torchrec"

# Register a record-discarding NullHandler under the TorchRec logger name in
# the shared handler registry.
_log_handlers[TORCHREC_LOGGER_NAME] = logging.NullHandler()
16 |
--------------------------------------------------------------------------------
/torchrec/fx/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | """Torchrec Tracer
11 |
12 | Custom FX tracer for torchrec
13 |
See `Torch.FX documentation <https://pytorch.org/docs/stable/fx.html>`_
15 | """
16 |
17 | from torchrec.fx.tracer import symbolic_trace, Tracer # noqa
18 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/notification.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | namespace tde::details {
6 |
7 | /**
8 | * Multi-thread notification
9 | */
class Notification {
 public:
  // Mark the notification as done and wake every thread blocked in Wait().
  void Done();
  // Block until Done() has been called (or has already been called).
  void Wait();

  /**
   * Clear the set status.
   *
   * NOTE: Clear is not thread-safe.
   */
  void Clear();

 private:
  // True once Done() has run; reset by Clear().
  bool set_{false};
  // Guards set_ in Done() and Wait().
  std::mutex mtx_;
  // Notified by Done(); waited on by Wait().
  std::condition_variable cv_;
};
27 |
28 | } // namespace tde::details
29 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/benchmark_zch/data/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 | from .criteo_kaggle import get_criteo_kaggle_dataloader
10 | from .kuairand_1k import get_kuairand_1k_dataloader
11 | from .kuairand_27k import get_kuairand_27k_dataloader
12 | from .movielens_1m import get_movielens_1m_dataloader
13 |
--------------------------------------------------------------------------------
/torchrec/distributed/global_settings.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
# Module-wide flag; read it through get_propogate_device(). Defaults to off.
PROPOGATE_DEVICE: bool = False


def set_propogate_device(val: bool) -> None:
    """Overwrite the module-level ``PROPOGATE_DEVICE`` flag with ``val``."""
    global PROPOGATE_DEVICE
    PROPOGATE_DEVICE = val


def get_propogate_device() -> bool:
    """Return the current value of the module-level ``PROPOGATE_DEVICE`` flag."""
    # Only reading the global, so no ``global`` declaration is needed.
    return PROPOGATE_DEVICE
21 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/tools/build_wheels.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Build torchrec_dynamic_embedding wheels for Linux x86_64 with cibuildwheel.
# Works from any working directory: it first switches to the directory above
# tools/ (quoted so paths containing spaces work; abort if the cd fails).
cd "$(dirname "$0")/../" || exit 1
export CIBW_BEFORE_BUILD="tools/before_linux_build.sh"

# Use env CIBW_BUILD="cp*-manylinux_x86_64" tools/build_wheels.sh to build
# all kinds of CPython.
export CIBW_BUILD=${CIBW_BUILD:-"cp39-manylinux_x86_64"}

export CIBW_MANYLINUX_X86_64_IMAGE=${CIBW_MANYLINUX_X86_64_IMAGE:-"manylinux_2_28"}

# Do not run auditwheel, since tde uses torch's shared libraries; the repair
# step only renames the wheel with a CUDA local-version tag.
export CIBW_REPAIR_WHEEL_COMMAND="tools/repair_wheel.sh {wheel} {dest_dir}"

cibuildwheel --platform linux --archs x86_64
15 |
--------------------------------------------------------------------------------
/torchrec/distributed/tests/test_model_parallel_gloo_gpu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | from torchrec.distributed.test_utils.test_model_parallel import ModelParallelBase
11 |
12 | # GPU tests for Gloo.
13 |
14 |
class ModelParallelTestGloo(ModelParallelBase):
    """Runs the ``ModelParallelBase`` tests with the Gloo backend.

    Per the module comment, these are the GPU variants of the Gloo tests.
    """

    def setUp(self, backend: str = "gloo") -> None:
        # Forward the backend (default "gloo") to the base-class setUp.
        super().setUp(backend=backend)
18 |
--------------------------------------------------------------------------------
/test/cpp/dynamic_embedding/notification_test.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include
10 | #include
11 | #include
12 |
13 | namespace torchrec {
14 | TEST(TDE, notification) {
15 | Notification notification;
16 | std::thread th([&] { notification.done(); });
17 | notification.wait();
18 | th.join();
19 | }
20 | } // namespace torchrec
21 |
--------------------------------------------------------------------------------
/examples/transfer_learning/README.md:
--------------------------------------------------------------------------------
1 | # Transfer Learning example
2 |
This example showcases training a distributed model using TorchRec. The embeddings are initialized with pretrained values (assumed to be loaded from storage, such as parquet). The pretrained values are large enough that we use the `share_memory_` API to load the tensors from shared memory.
4 |
5 | See [`torch.multiprocessing`](https://pytorch.org/docs/stable/multiprocessing.html) and [best practices](https://pytorch.org/docs/1.6.0/notes/multiprocessing.html?highlight=multiprocessing) for more information on shared memory.
6 |
7 | ## Running
8 |
9 | `python train_from_pretrained_embedding.py`
10 |
--------------------------------------------------------------------------------
/test/cpp/dynamic_embedding/redis/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
# Define a gtest executable NAME from the sources in ARGN, link it against
# redis_io, tde_cpp_objs and gtest, and register it as a CTest test.
function(add_redis_test NAME)
  add_executable(${NAME} ${ARGN})
  target_link_libraries(${NAME}
    redis_io
    tde_cpp_objs
    gtest
    gtest_main
  )
  add_test(NAME ${NAME} COMMAND ${NAME})
endfunction()
12 |
# TODO: An empty redis-server must be running on 127.0.0.1:6379
# before running the *redis*_test targets.
15 | add_redis_test(redis_io_test redis_io_test.cpp)
16 | add_redis_test(url_test url_test.cpp)
17 |
--------------------------------------------------------------------------------
/torchrec/optim/test_utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | from typing import Any
11 |
12 | from torchrec.optim.keyed import KeyedOptimizer
13 |
14 |
class DummyKeyedOptimizer(KeyedOptimizer):
    """Test-only ``KeyedOptimizer`` whose ``step`` is a no-op.

    The base ``step`` raises ``NotImplementedError``; this subclass overrides
    it so tests can build and drive a keyed optimizer without real updates.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Pass every constructor argument straight through to KeyedOptimizer.
        super().__init__(*args, **kwargs)

    # pyre-ignore[2]
    def step(self, closure: Any = None) -> None:
        """Do nothing.

        Args:
            closure: ignored. Defaults to ``None`` like
                ``torch.optim.Optimizer.step(closure=None)``, so ``step()``
                can be called without arguments.
        """
        pass  # Override NotImplementedError.
22 |
--------------------------------------------------------------------------------
/test/cpp/dynamic_embedding/redis/url_test.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include
10 | #include
11 |
namespace torchrec::url_parser::rules {

// parse_url on a "host/?query" string: the host is extracted without the
// trailing "/?...", and the query text after "?" is kept verbatim (including
// the doubled "&&") in the optional `param` field.
TEST(TDE, url) {
  auto url = parse_url("www.qq.com/?a=b&&c=d");
  ASSERT_EQ(url.host, "www.qq.com");
  ASSERT_TRUE(url.param.has_value());
  ASSERT_EQ("a=b&&c=d", url.param.value());
}

} // namespace torchrec::url_parser::rules
22 |
--------------------------------------------------------------------------------
/.lintrunner.toml:
--------------------------------------------------------------------------------
1 | [[linter]]
2 | code = 'BLACK'
3 | include_patterns = ['**/*.py']
4 | command = [
5 | 'python3',
6 | 'tools/lint/black_linter.py',
7 | '--',
8 | '@{{PATHSFILE}}'
9 | ]
10 | init_command = [
11 | 'python3',
12 | 'tools/lint/pip_init.py',
13 | '--dry-run={{DRYRUN}}',
14 | 'black==24.2.0',
15 | ]
16 | is_formatter = true
17 |
18 | [[linter]]
19 | code = 'USORT'
20 | include_patterns = ['**/*.py']
21 | command = [
22 | 'python3',
23 | 'tools/lint/usort_linter.py',
24 | '--',
25 | '@{{PATHSFILE}}'
26 | ]
27 | init_command = [
28 | 'python3',
29 | 'tools/lint/pip_init.py',
30 | '--dry-run={{DRYRUN}}',
31 | 'usort==1.0.8',
32 | ]
33 | is_formatter = true
34 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/torchrec/csrc/dynamic_embedding/details/notification.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include "notification.h"
10 |
11 | namespace torchrec {
12 | void Notification::done() {
13 | {
14 | std::lock_guard guard(mu_);
15 | set_ = true;
16 | }
17 | cv_.notify_all();
18 | }
19 | void Notification::wait() {
20 | std::unique_lock lock(mu_);
21 | cv_.wait(lock, [this] { return set_; });
22 | }
23 |
24 | void Notification::clear() {
25 | set_ = false;
26 | }
27 | } // namespace torchrec
28 |
--------------------------------------------------------------------------------
/torchrec/inference/include/torchrec/inference/JaggedTensor.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 |
11 | #include
12 | #include
13 |
14 | #include
15 |
16 | namespace torchrec {
17 |
18 | struct JaggedTensor {
19 | at::Tensor lengths;
20 | at::Tensor values;
21 | at::Tensor weights;
22 | };
23 |
24 | struct KeyedJaggedTensor {
25 | std::vector keys;
26 | at::Tensor lengths;
27 | at::Tensor values;
28 | at::Tensor weights;
29 | };
30 |
31 | } // namespace torchrec
32 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/id_transformer_variant_test.cpp:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "tde/details/id_transformer_variant.h"
3 |
4 | namespace tde::details {
5 |
6 | TEST(TDE, CreateLXUStrategy) {
7 | auto strategy = IDTransformer::LXUStrategy(
8 | {{"min_used_freq_power", 6}, {"type", "mixed_lru_lfu"}});
9 | }
10 |
11 | TEST(TDE, IDTransformer) {
12 | IDTransformer transformer(1000, nlohmann::json::parse(R"(
13 | {
14 | "lxu_strategy": {"type": "mixed_lru_lfu"},
15 | "id_transformer": {"type": "naive"}
16 | }
17 | )"));
18 | std::vector vec{0, 1, 2};
19 | std::vector result;
20 | result.resize(vec.size());
21 | transformer.Transform(vec, result);
22 | }
23 |
24 | } // namespace tde::details
25 |
--------------------------------------------------------------------------------
/docs/source/inference-api-reference.rst:
--------------------------------------------------------------------------------
1 | Inference
2 | ----------------------------------
3 |
4 | TorchRec provides easy-to-use APIs for transforming an authored TorchRec model
5 | into an optimized inference model for distributed inference, via eager module swaps.
6 |
7 | This transforms TorchRec modules like ``EmbeddingBagCollection`` in the model to
8 | a quantized, sharded version that can be compiled using torch.fx and TorchScript
9 | for inference in a C++ environment.
10 |
11 | The intended use is calling ``quantize_inference_model`` on the model followed by
12 | ``shard_quant_model``.
13 |
.. code-block:: python

    quant_model = quantize_inference_model(model)
    sharded_model, _ = shard_quant_model(quant_model)
15 |
16 | .. automodule:: torchrec.inference.modules
17 |
18 | .. autofunction:: quantize_inference_model
19 | .. autofunction:: shard_quant_model
20 |
--------------------------------------------------------------------------------
/torchrec/inference/inference_legacy/include/torchrec/inference/JaggedTensor.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 |
11 | #include
12 | #include
13 |
14 | #include
15 |
16 | namespace torchrec {
17 |
18 | struct JaggedTensor {
19 | at::Tensor lengths;
20 | at::Tensor values;
21 | at::Tensor weights;
22 | };
23 |
24 | struct KeyedJaggedTensor {
25 | std::vector keys;
26 | at::Tensor lengths;
27 | at::Tensor values;
28 | at::Tensor weights;
29 | };
30 |
31 | } // namespace torchrec
32 |
--------------------------------------------------------------------------------
/test/cpp/dynamic_embedding/bits_op_test.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include
10 | #include
11 |
namespace torchrec {
// clz: the expected value 1 for 0x7F.., in each of int8/int32/int64, matches a
// count of leading zero bits within the argument's own width.
TEST(TDE, bits_op_clz) {
  ASSERT_EQ(clz(int32_t(0x7FFFFFFF)), 1);
  ASSERT_EQ(clz(int64_t(0x7FFFFFFFFFFFFFFF)), 1);
  ASSERT_EQ(clz(int8_t(0x7F)), 1);
}

// ctz: the expected values (0x2 -> 1, 0xF00 -> 8, 0x4 -> 2) match a count of
// trailing zero bits.
TEST(TDE, bits_op_ctz) {
  ASSERT_EQ(ctz(int32_t(0x2)), 1);
  ASSERT_EQ(ctz(int64_t(0xF00)), 8);
  ASSERT_EQ(ctz(int8_t(0x4)), 2);
}
} // namespace torchrec
25 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/benchmark_zch/models/configs/dlrmv3.yaml:
--------------------------------------------------------------------------------
1 | hstu_num_heads: 4 # 1 for hstu, 2 for hstu-large
2 | hstu_attn_num_layers: 3 # 2 for hstu, 8 for hstu-large
3 | hstu_embedding_table_dim: 256
4 | hstu_transducer_embedding_dim: 512
5 | hstu_attn_linear_dim: 128
6 | hstu_attn_qk_dim: 128
7 | hstu_input_dropout_ratio: 0.2
8 | hstu_linear_dropout_rate: 0.1
9 | causal_multitask_weights: 0.2
10 | num_embeddings: 100000
11 | embedding_module_attribute_path: "dlrm_hstu._embedding_collection" # the attribute path of embedding module after model
12 | managed_collision_module_attribute_path: "module.dlrm_hstu._embedding_collection.mc_embedding_collection._managed_collision_collection._managed_collision_modules" # the attribute path of managed collision module after model
13 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/tensor_list.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | namespace tde {
5 |
6 | class TensorList : public torch::CustomClassHolder {
7 | using Container = std::vector;
8 |
9 | public:
10 | TensorList() = default;
11 |
12 | void push_back(at::Tensor tensor) {
13 | tensors_.push_back(tensor);
14 | }
15 | int64_t size() const {
16 | return tensors_.size();
17 | }
18 | torch::Tensor& operator[](int64_t index) {
19 | return tensors_[index];
20 | }
21 |
22 | Container::const_iterator begin() const {
23 | return tensors_.begin();
24 | }
25 | Container::const_iterator end() const {
26 | return tensors_.end();
27 | }
28 |
29 | private:
30 | Container tensors_;
31 | };
32 |
33 | } // namespace tde
34 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/utils.py:
--------------------------------------------------------------------------------
from typing import Dict, Tuple

import torch.nn as nn
from torchrec.distributed.types import ModuleShardingPlan, ShardingPlan
5 |
6 |
__all__ = []


def _get_sharded_modules_recursive(
    module: nn.Module,
    path: str,
    plan: ShardingPlan,
) -> Dict[str, Tuple[nn.Module, ModuleShardingPlan]]:
    """
    Get all sharded modules of ``module`` from ``plan``.

    Walks the module tree depth-first. When ``plan`` has a (truthy) entry for
    the current ``path``, the walk stops there and that module is returned
    together with its plan. Otherwise the children are visited with dotted
    paths (``"parent.child"``, or just ``"child"`` when ``path`` is empty).

    Args:
        module: module to inspect.
        path: fully qualified name of ``module`` ("" for the root).
        plan: sharding plan whose per-module entries are looked up by path.

    Returns:
        Mapping from module path to a ``(module, module_plan)`` tuple. The
        previous annotation (``Dict[str, nn.Module]``) did not match these
        tuple values.
    """
    params_plan = plan.get_plan_for_module(path)
    if params_plan:
        return {path: (module, params_plan)}

    res: Dict[str, Tuple[nn.Module, ModuleShardingPlan]] = {}
    for name, child in module.named_children():
        new_path = f"{path}.{name}" if path else name
        res.update(_get_sharded_modules_recursive(child, new_path, plan))
    return res
27 |
--------------------------------------------------------------------------------
/benchmarks/cpp/dynamic_embedding/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
# Define a benchmark executable NAME from the sources in ARGN, linked against
# tde_cpp_objs and the Google Benchmark targets (benchmark_main + benchmark).
function(add_tde_benchmark NAME)
  add_executable(${NAME} ${ARGN})
  target_link_libraries(${NAME}
    tde_cpp_objs
    benchmark::benchmark_main
    benchmark::benchmark
  )
endfunction()
11 |
12 | add_tde_benchmark(naive_id_transformer_benchmark naive_id_transformer_benchmark.cpp)
13 | add_tde_benchmark(random_bits_generator_benchmark random_bits_generator_benchmark.cpp)
14 | add_tde_benchmark(mixed_lfu_lru_strategy_benchmark mixed_lfu_lru_strategy_benchmark.cpp)
15 | add_tde_benchmark(mixed_lfu_lru_strategy_evict_benchmark mixed_lfu_lru_strategy_evict_benchmark.cpp)
16 |
--------------------------------------------------------------------------------
/torchrec/csrc/dynamic_embedding/details/types.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 |
11 | #include
12 | #include
13 | #include
14 |
namespace torchrec {

// Per-id usage record; presumably the state kept by the LXU (LRU/LFU)
// eviction strategies — TODO confirm against the strategy implementations.
using lxu_record_t = uint32_t;

// One id mapping: a global id, the cache id it is assigned to, and its
// usage record.
struct record_t {
  int64_t global_id;
  int64_t cache_id;
  lxu_record_t lxu_record;
};

// NOTE(review): the template arguments of the three std::function aliases
// below are missing (`std::function()>` / `std::function;` are not valid
// C++), most likely lost when this text was extracted. Restore them from the
// upstream header; presumably iterator_t wraps
// `std::optional<record_t>()` — TODO confirm all three signatures.
using iterator_t = std::function()>;
using update_t =
    std::function)>;
using fetch_t = std::function;

} // namespace torchrec
31 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yaml:
--------------------------------------------------------------------------------
name: pre-commit

on:
  push:
    branches:
      # only run tests on main branch & nightly; release should be triggered manually
      - nightly
      - main
    tags:
      # Release candidate tags look like: v1.11.0-rc1
      - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
  pull_request:

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML does not parse the version as a float.
          python-version: "3.12"
          architecture: x64
        # NOTE: actions/setup-python has no `packages` input, so the formatter
        # pins that used to be listed here (ufmt==2.5.1, black==24.2.0,
        # usort==1.0.8) were silently ignored. pre-commit installs hook tools
        # itself from .pre-commit-config.yaml, so pin the versions there.
      - name: Checkout Torchrec
        uses: actions/checkout@v4
      - name: Run pre-commit
        uses: pre-commit/action@v3.0.1
31 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/tests/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 |
7 | __all__ = []
8 |
9 |
MEMORY_IO_REGISTERED = False


def register_memory_io():
    """Register the in-memory TDE IO backend at most once per process.

    The backend location is read from the ``TDE_MEMORY_IO_PATH`` environment
    variable and passed to ``torch.ops.tde.register_io``. The module-level
    ``MEMORY_IO_REGISTERED`` flag makes repeated calls no-ops after the first
    successful registration.

    Raises:
        RuntimeError: if ``TDE_MEMORY_IO_PATH`` is not set.
    """
    global MEMORY_IO_REGISTERED
    if MEMORY_IO_REGISTERED:
        return

    io_path = os.getenv("TDE_MEMORY_IO_PATH")
    if io_path is None:
        raise RuntimeError("env TDE_MEMORY_IO_PATH must set for unittest")

    torch.ops.tde.register_io(io_path)
    MEMORY_IO_REGISTERED = True
22 |
23 |
def init_dist(backend: str = "nccl"):
    """Initialize a single-process (rank 0 of 1) default process group.

    Does nothing if the default process group is already initialized.

    ``RANK`` and ``WORLD_SIZE`` are always forced to a single-process layout.
    ``MASTER_ADDR`` / ``MASTER_PORT`` are only filled in when not already set.
    Callers, e.g. concurrently running test jobs, can therefore choose a free
    port instead of all colliding on the hard-coded default.

    Args:
        backend: process-group backend passed to ``init_process_group``;
            defaults to ``"nccl"`` (the previous hard-coded value).
    """
    if dist.is_initialized():
        return
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    # Respect an address/port supplied by the caller's environment.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "13579")
    dist.init_process_group(backend)
31 |
--------------------------------------------------------------------------------
/docs/source/datatypes-api-reference.rst:
--------------------------------------------------------------------------------
1 | Data Types
2 | -------------------
3 |
4 |
5 | TorchRec contains data types for representing embedding inputs, otherwise known as sparse features.
6 | Sparse features are typically indices that are meant to be fed into embedding tables. For a given
7 | batch, the number of embedding lookup indices is variable. Therefore, there is a need for a **jagged**
8 | dimension to represent the variable number of embedding lookup indices in a batch.
9 |
10 | This section covers the classes for the 3 TorchRec data types for representing sparse features:
11 | **JaggedTensor**, **KeyedJaggedTensor**, and **KeyedTensor**.
12 |
13 | .. automodule:: torchrec.sparse.jagged_tensor
14 |
15 | .. autoclass:: JaggedTensor
16 | :members:
17 |
18 | .. autoclass:: KeyedJaggedTensor
19 | :members:
20 |
21 | .. autoclass:: KeyedTensor
22 | :members:
23 |
--------------------------------------------------------------------------------
/test/cpp/dynamic_embedding/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
# Creates a GoogleTest executable NAME from the given sources, links it against
# the TDE object library, and registers it with CTest under the same name.
function(add_tde_test NAME)
  add_executable(${NAME} ${ARGN})
  target_link_libraries(${NAME} tde_cpp_objs gtest gtest_main)
  add_test(NAME ${NAME} COMMAND ${NAME})
endfunction()

# Each test is a single source file named after its target.
foreach(_tde_test IN ITEMS
    bits_op_test
    naive_id_transformer_test
    random_bits_generator_test
    mixed_lfu_lru_strategy_test
    notification_test)
  add_tde_test(${_tde_test} ${_tde_test}.cpp)
endforeach()

# Redis IO tests are only built when the optional Redis IO backend is enabled.
if(BUILD_REDIS_IO)
  add_subdirectory(redis)
endif()
22 |
--------------------------------------------------------------------------------
/torchrec/distributed/tests/test_model_parallel_gloo_gpu_single_rank.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | from torchrec.distributed.test_utils.test_model_parallel_base import (
11 | ModelParallelSparseOnlyBase,
12 | ModelParallelStateDictBase,
13 | )
14 |
15 | # Single rank GPU tests for Gloo.
16 |
17 |
class ModelParallelStateDictTestGloo(ModelParallelStateDictBase):
    """Single-rank state-dict model-parallel tests on the Gloo backend."""

    def setUp(self, backend: str = "gloo") -> None:
        # Only the backend differs from the base suite.
        super().setUp(backend=backend)
21 |
22 |
class ModelParallelSparseOnlyTestGloo(ModelParallelSparseOnlyBase):
    """Single-rank sparse-only model-parallel tests on the Gloo backend."""

    def setUp(self, backend: str = "gloo") -> None:
        # Only the backend differs from the base suite.
        super().setUp(backend=backend)
26 |
--------------------------------------------------------------------------------
/.github/workflows/pyre.yml:
--------------------------------------------------------------------------------
name: Pyre Check

on:
  push:
    branches: [main]
  pull_request:

jobs:
  pyre-check:
    runs-on: ubuntu-22.04
    defaults:
      run:
        # Login shell (-l) so the environment created by setup-miniconda is
        # activated in every step; -e stops at the first failing command.
        shell: bash -el {0}
    steps:
      - uses: conda-incubator/setup-miniconda@v2
        with:
          python-version: 3.9
      - name: Checkout Torchrec
        uses: actions/checkout@v4
      # Installs CPU nightly torch / fbgemm-gpu plus requirements.txt
      # (presumably so pyre can resolve their types -- confirm), then
      # pyre-check-nightly pinned to the "version" field of .pyre_configuration.
      # The folded scalar (>) joins these lines into one `&&` chain.
      - name: Install dependencies
        run: >
          pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu &&
          pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu &&
          pip install -r requirements.txt &&
          pip install pyre-check-nightly==$(cat .pyre_configuration | grep version | awk '{print $2}' | sed 's/\"//g')
      - name: Pyre check
        run: pyre check
28 |
--------------------------------------------------------------------------------
/.github/workflows/validate-nightly-binaries.yml:
--------------------------------------------------------------------------------
# Scheduled validation of the nightly binaries
name: validate-nightly-binaries

on:
  schedule:
    # At 5:30 pm UTC (10:30 am PDT)
    - cron: "30 17 * * *"
  # Have the ability to trigger this job manually through the API
  workflow_dispatch:
  push:
    branches:
      - main
    # Keep this list identical to the pull_request paths below.
    # NOTE(review): the two lists previously disagreed on the script name
    # (validate-binaries.sh vs validate_binaries.sh); the underscore spelling
    # is used for both -- confirm it matches the file in .github/scripts.
    paths:
      - '.github/workflows/validate-nightly-binaries.yml'
      - '.github/workflows/validate-binaries.yml'
      - '.github/scripts/validate_binaries.sh'
      - '.github/scripts/filter.py'
  pull_request:
    paths:
      - '.github/workflows/validate-nightly-binaries.yml'
      - '.github/workflows/validate-binaries.yml'
      - '.github/scripts/validate_binaries.sh'
      - '.github/scripts/filter.py'
jobs:
  nightly:
    uses: ./.github/workflows/validate-binaries.yml
    with:
      channel: nightly
28 |
--------------------------------------------------------------------------------
/torchrec/csrc/dynamic_embedding/details/notification.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
#pragma once
// NOTE(review): the header names of these three includes are missing in this
// copy (likely stripped during extraction). The members below need the
// std::mutex / std::condition_variable headers and the torch custom-class
// header -- restore before building.
#include
#include
#include

namespace torchrec {

/**
 * Multi-thread notification
 *
 * State: a `set_` flag (initially false), a mutex and a condition variable.
 * Presumably done() sets the flag and wakes waiters, and wait() blocks until
 * the flag is set. The definitions are not in this header -- confirm against
 * the implementation.
 *
 * Derives from torch::CustomClassHolder so it can be exposed as a torch
 * custom class.
 */
class Notification : public torch::CustomClassHolder {
 public:
  Notification() = default;

  void done();
  void wait();

  /**
   * Clear the set status.
   *
   * NOTE: Clear is not thread-safe.
   */
  void clear();

 private:
  bool set_{false};
  std::mutex mu_;
  std::condition_variable cv_;
};

} // namespace torchrec
40 |
--------------------------------------------------------------------------------
/docs/source/modules-api-reference.rst:
--------------------------------------------------------------------------------
1 | Modules
2 | ----------------------------------
3 |
4 | Standard TorchRec modules represent collections of embedding tables:
5 |
6 | * ``EmbeddingBagCollection`` is a collection of ``torch.nn.EmbeddingBag``
7 | * ``EmbeddingCollection`` is a collection of ``torch.nn.Embedding``
8 |
9 | These modules are constructed through standardized config classes:
10 |
11 | * ``EmbeddingBagConfig`` for ``EmbeddingBagCollection``
12 | * ``EmbeddingConfig`` for ``EmbeddingCollection``
13 |
14 | .. automodule:: torchrec.modules.embedding_configs
15 |
16 | .. autoclass:: EmbeddingBagConfig
17 | :show-inheritance:
18 |
19 | .. autoclass:: EmbeddingConfig
20 | :show-inheritance:
21 |
22 | .. autoclass:: BaseEmbeddingConfig
23 |
24 | .. automodule:: torchrec.modules.embedding_modules
25 |
26 | .. autoclass:: EmbeddingBagCollection
27 | :members:
28 |
29 | .. autoclass:: EmbeddingCollection
30 | :members:
31 |
--------------------------------------------------------------------------------
/torchrec/modules/object_pool.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import abc
11 | from typing import Generic, TypeVar
12 |
13 | import torch
14 |
T = TypeVar("T")


class ObjectPool(abc.ABC, torch.nn.Module, Generic[T]):
    """
    Interface for TensorPool and KeyedJaggedTensorPool

    Defines methods for lookup, update and obtaining pool size.

    Subclasses must implement ``lookup``, ``update`` and the read-only
    ``pool_size`` property. ``T`` is the value type: ``lookup`` returns it and
    ``update`` accepts it.
    """

    @abc.abstractmethod
    def lookup(self, ids: torch.Tensor) -> T:
        """Look up the pool entries for ``ids``."""
        pass

    @abc.abstractmethod
    def update(self, ids: torch.Tensor, values: T) -> None:
        """Update the pool entries at ``ids`` with ``values``."""
        pass

    # ``abc.abstractproperty`` has been deprecated since Python 3.3; stacking
    # ``@property`` on top of ``@abc.abstractmethod`` is the supported form and
    # still makes subclasses that omit ``pool_size`` uninstantiable.
    @property
    @abc.abstractmethod
    def pool_size(self) -> int:
        """Size of the pool."""
        pass
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation
REM Usage: make.bat BUILDER  (e.g. "make.bat html"). With no argument, Sphinx's
REM "make mode" help is printed. SPHINXOPTS and O are forwarded to sphinx-build.

REM Use sphinx-build from PATH unless SPHINXBUILD is already set.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

REM Probe for sphinx-build; errorlevel 9009 is reported when the command is missing.
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

REM Delegate to Sphinx "make mode" (-M) with the requested builder.
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
36 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/bits_op.h:
--------------------------------------------------------------------------------
#pragma once
// NOTE(review): the header names of these includes and the template parameter
// lists in this file (e.g. `template <typename T>`, `bits_impl::Clz<T>`) are
// missing in this copy, presumably stripped during extraction -- restore them
// before building.
#include
#include

namespace tde::details {

namespace bits_impl {
// Per-type functors. Only declared here; the definitions (presumably
// specializations for the supported integer types) live in other files.
template
struct Clz {
  int operator()(T v) const;
};

template
struct Ctz {
  int operator()(T v) const;
};
} // namespace bits_impl

/**
 * Returns the number of leading 0-bits in t, starting at the most significant
 * bit position. If t is 0, the result is undefined.
 */
template
inline int Clz(T t) {
  bits_impl::Clz clz;
  return clz(t);
}

/**
 * Returns the number of trailing 0-bits in t, starting at the least significant
 * bit position. If t is 0, the result is undefined.
 */
template
inline int Ctz(T t) {
  bits_impl::Ctz ctz;
  return ctz(t);
}

} // namespace tde::details
40 |
--------------------------------------------------------------------------------
/examples/retrieval/tests/test_two_tower_retrieval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import unittest
11 |
12 | import torch
13 |
14 | from torchrec.test_utils import skip_if_asan
15 |
16 | # @manual=//torchrec/github/examples/retrieval:two_tower_retrieval_lib
17 | from ..two_tower_retrieval import infer
18 |
19 |
class InferTest(unittest.TestCase):
    """Smoke test for the two-tower retrieval ``infer`` entry point."""

    @skip_if_asan
    # pyre-ignore[56]
    @unittest.skipIf(
        torch.cuda.device_count() < 2,
        "Not enough GPUs, this test requires at least two GPUs",
    )
    def test_infer_function(self) -> None:
        # world_size=2 matches the two-GPU requirement above; tiny dims keep it cheap.
        infer(embedding_dim=16, layer_sizes=[16], world_size=2)
33 |
--------------------------------------------------------------------------------
/tools/lint/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import os
10 | from enum import Enum
11 | from typing import NamedTuple, Optional
12 |
IS_WINDOWS: bool = (os.name == "nt")


class LintSeverity(str, Enum):
    """Severity level of a lint message; each value is the lowercase name."""

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"


class LintMessage(NamedTuple):
    """A single lint finding. Fields typed ``Optional`` may be ``None``."""

    path: Optional[str]
    line: Optional[int]
    char: Optional[int]
    code: str
    severity: LintSeverity
    name: str
    original: Optional[str]
    replacement: Optional[str]
    description: Optional[str]


def as_posix(name: str) -> str:
    """Return ``name`` with backslashes replaced by forward slashes on Windows.

    On any other platform ``name`` is returned unchanged.
    """
    if not IS_WINDOWS:
        return name
    return name.replace("\\", "/")
37 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/redis_io.cpp:
--------------------------------------------------------------------------------
1 | #include "redis_io.h"
2 | #include "tde/details/io_registry.h"
3 | #include "tde/details/redis_io_v1.h"
4 |
namespace tde::details {

// Registers an IO provider of type "redis" with the global IORegistry.
// Each callback is a captureless lambda turned into a plain function pointer
// by the unary `+`.
void RegisterRedisIO() {
  auto& reg = IORegistry::Instance();

  {
    IOProvider provider{};
    provider.type_ = "redis";
    // Parses the config string and returns a heap-allocated RedisV1 as an
    // opaque handle; the handle is what the callbacks below receive as `inst`.
    provider.Initialize = +[](const char* cfg) -> void* {
      auto opt = redis_v1::Option::Parse(cfg);
      return new redis_v1::RedisV1(opt);
    };
    // NOTE(review): the reinterpret_cast target types below are missing in
    // this copy (likely stripped during extraction); presumably
    // reinterpret_cast<redis_v1::RedisV1*>(inst) -- restore before building.
    // Deletes the instance created by Initialize.
    provider.Finalize =
        +[](void* inst) { delete reinterpret_cast(inst); };
    // Forward Pull/Push to the RedisV1 instance.
    provider.Pull = +[](void* inst, IOPullParameter param) {
      reinterpret_cast(inst)->Pull(param);
    };
    provider.Push = +[](void* inst, IOPushParameter param) {
      reinterpret_cast(inst)->Push(param);
    };
    reg.Register(provider);
  }
}
} // namespace tde::details
29 |
--------------------------------------------------------------------------------
/torchrec/modules/tests/test_activation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import unittest
11 |
12 | import torch
13 | from torchrec.fx import Tracer
14 | from torchrec.modules.activation import SwishLayerNorm
15 |
16 |
class TestActivation(unittest.TestCase):
    """Tests for ``SwishLayerNorm``."""

    def test_swish_takes_float(self) -> None:
        # Reference result: x * sigmoid(LayerNorm(x)) over the trailing [3, 4] dims.
        layer = SwishLayerNorm([3, 4])
        x = torch.randn(2, 3, 4)
        actual = layer(x)
        reference_norm = torch.nn.LayerNorm([3, 4])
        expected = x * torch.sigmoid(reference_norm(x))
        self.assertTrue(torch.allclose(actual, expected))

    def test_fx_script_swish(self) -> None:
        # The module must survive FX tracing followed by TorchScript compilation.
        layer = SwishLayerNorm(10)
        graph = Tracer().trace(layer)
        torch.jit.script(torch.fx.GraphModule(layer, graph))
31 |
--------------------------------------------------------------------------------
/torchrec/datasets/scripts/nvt/nvt_preproc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
set -e

# Usage: nvt_preproc.sh INPUT_PATH BASE_OUTPUT_PATH BATCH_SIZE
# BASE_OUTPUT_PATH is used as a plain string prefix (no separator is added),
# so it should end with "/".
if [ "$#" -lt 3 ]; then
    echo "Usage: $0 INPUT_PATH BASE_OUTPUT_PATH BATCH_SIZE" >&2
    exit 1
fi

INPUT_PATH="$1"
BASE_OUTPUT_PATH="$2"
BATCH_SIZE="$3"
TEMP_PATH="${BASE_OUTPUT_PATH}temp/"
SRC_DIR="${BASE_OUTPUT_PATH}criteo_preproc/"
BINARY_OUTPUT_PATH="${BASE_OUTPUT_PATH}criteo_binary/"
FINAL_OUTPUT_PATH="${BINARY_OUTPUT_PATH}split"

# TSV -> parquet -> preprocessed parquet -> binary -> split binary.
python convert_tsv_to_parquet.py -i "$INPUT_PATH" -o "$BASE_OUTPUT_PATH"
python process_criteo_parquet.py -b "$BASE_OUTPUT_PATH" -s
python convert_parquet_to_binary.py --src_dir "$SRC_DIR" \
    --intermediate_dir "$TEMP_PATH" \
    --dst_dir "$BINARY_OUTPUT_PATH"
python split_binary_dataset.py --input_path "$BINARY_OUTPUT_PATH" --output_path "$FINAL_OUTPUT_PATH" --batch_size "$BATCH_SIZE"
24 |
--------------------------------------------------------------------------------
/torchrec/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | """Torchrec Models
11 |
12 | Torchrec provides the architecture for two popular recsys models:
13 | `DeepFM <https://arxiv.org/abs/1703.04247>`_ and `DLRM (Deep Learning Recommendation Model)
14 | <https://arxiv.org/abs/1906.00091>`_.
15 |
16 | Along with the overall model, the individual architectures of each layer are also
17 | provided (e.g. `SparseArch`, `DenseArch`, `InteractionArch`, and `OverArch`).
18 |
19 | Examples can be found within each model.
20 |
21 | The following notation is used throughout the documentation for the models:
22 |
23 | * F: number of sparse features
24 | * D: embedding_dimension of sparse features
25 | * B: batch size
26 | * num_features: number of dense features
27 | """
28 |
29 | from torchrec.models import deepfm, dlrm # noqa
30 |
--------------------------------------------------------------------------------
/torchrec/csrc/dynamic_embedding/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
# All dynamic-embedding C++ sources are compiled once into an object library.
set(_tde_cpp_srcs
  bind.cpp
  id_transformer_wrapper.cpp
  ps.cpp
  details/clz_impl.cpp
  details/ctz_impl.cpp
  details/random_bits_generator.cpp
  details/io_registry.cpp
  details/io.cpp
  details/notification.cpp)
add_library(tde_cpp_objs OBJECT ${_tde_cpp_srcs})

# The Redis IO backend is optional.
if(BUILD_REDIS_IO)
  add_subdirectory(details/redis)
endif()

# First include dir is the repository root (three levels above
# torchrec/csrc/dynamic_embedding).
target_include_directories(tde_cpp_objs PUBLIC
  ${CMAKE_CURRENT_SOURCE_DIR}/../../../
  ${TORCH_INCLUDE_DIRS})
target_compile_options(tde_cpp_objs PUBLIC -fPIC)
target_link_libraries(tde_cpp_objs PUBLIC ${TORCH_LIBRARIES} ${CMAKE_DL_LIBS})
28 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/random_bits_generator_benchmark.cpp:
--------------------------------------------------------------------------------
1 | #include "benchmark/benchmark.h"
2 | #include "tde/details/random_bits_generator.h"
3 |
namespace tde::details {

// Benchmarks BitScanner::IsNextNBitsAllZero.
//   range(0) "n":     passed to the BitScanner constructor.
//   range(1) "limit": upper bound of the randomly drawn bit count n_bits.
void BMRandomBitsGenerator(benchmark::State& state) {
  auto n = state.range(0);
  auto n_bits_limit = state.range(1);
  BitScanner scanner(n);
  std::mt19937_64 engine((std::random_device())());
  // NOTE(review): the template arguments of uniform_int_distribution here and
  // of tcb::span below are missing in this copy (likely stripped during
  // extraction) -- restore before building.
  std::uniform_int_distribution dist(1, n_bits_limit);
  // NOTE(review): stored as uint16_t, so limits above 65535 would truncate.
  uint16_t n_bits = dist(engine);
  for (auto _ : state) {
    // NOTE(review): dist draws from [1, n_bits_limit], so n_bits is never 0.
    // The else branch is therefore unreachable: n_bits keeps its first value
    // for the whole run and ResetArray is called on every iteration.
    // Confirm this is the intended measurement.
    if (n_bits != 0) {
      // Refill the scanner's array with fresh 64-bit random values.
      scanner.ResetArray([&](tcb::span span) {
        for (auto& v : span) {
          v = engine();
        }
      });
    } else {
      n_bits = dist(engine);
    }
    benchmark::DoNotOptimize(scanner.IsNextNBitsAllZero(n_bits));
  }
}

BENCHMARK(BMRandomBitsGenerator)
    ->ArgNames({"n", "limit"})
    ->Args({1024, 32})
    ->Unit(benchmark::kMillisecond)
    ->Iterations(1024 * 1024);

} // namespace tde::details
34 |
--------------------------------------------------------------------------------
/torchrec/inference/include/torchrec/inference/Validation.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 |
11 | #include "torchrec/inference/Types.h"
12 |
namespace torchrec {

// Returns whether sparse features (KeyedJaggedTensor) are valid.
// Currently validates:
// 1. Whether sum(lengths) == size(values)
// 2. Whether there are negative values in lengths
// 3. If weights is present, whether sum(lengths) == size(weights)
// NOTE(review): the template argument of std::optional is missing in this copy
// (likely stripped during extraction); presumably std::optional<at::Tensor>.
bool validateSparseFeatures(
    at::Tensor& values,
    at::Tensor& lengths,
    std::optional maybeWeights = std::nullopt);

// Returns whether dense features are valid.
// Currently validates:
// 1. Whether the size of values is divisible by batch size (request level)
bool validateDenseFeatures(at::Tensor& values, size_t batchSize);

} // namespace torchrec
31 |
--------------------------------------------------------------------------------
/torchrec/inference/inference_legacy/include/torchrec/inference/Validation.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 |
11 | #include "torchrec/inference/Types.h"
12 |
namespace torchrec {

// Returns whether sparse features (KeyedJaggedTensor) are valid.
// Currently validates:
// 1. Whether sum(lengths) == size(values)
// 2. Whether there are negative values in lengths
// 3. If weights is present, whether sum(lengths) == size(weights)
// NOTE(review): the template argument of std::optional is missing in this copy
// (likely stripped during extraction); presumably std::optional<at::Tensor>.
bool validateSparseFeatures(
    at::Tensor& values,
    at::Tensor& lengths,
    std::optional maybeWeights = std::nullopt);

// Returns whether dense features are valid.
// Currently validates:
// 1. Whether the size of values is divisible by batch size (request level)
bool validateDenseFeatures(at::Tensor& values, size_t batchSize);

} // namespace torchrec
31 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/id_transformer.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 | #include "tde/details/id_transformer_variant.h"
5 | #include "tde/tensor_list.h"
6 |
namespace tde {

// Result of IDTransformer::Transform: a success flag plus a tensor of ids to
// fetch. Derives from torch::CustomClassHolder so it can be exposed as a torch
// custom class.
struct TransformResult : public torch::CustomClassHolder {
  TransformResult(bool success, torch::Tensor ids_to_fetch)
      : success_(success), ids_to_fetch_(ids_to_fetch) {}
  bool success_;
  torch::Tensor ids_to_fetch_;
};

// Torch custom-class wrapper around details::IDTransformer.
// NOTE(review): template arguments (c10::intrusive_ptr<...>, std::vector<...>)
// are missing in this copy, likely stripped during extraction -- restore
// before building.
class IDTransformer : public torch::CustomClassHolder {
 public:
  IDTransformer(int64_t num_embeddings, nlohmann::json json);
  // Presumably maps global_ids into cache_ids at timestamp `time` -- confirm
  // against the implementation.
  c10::intrusive_ptr Transform(
      c10::intrusive_ptr global_ids,
      c10::intrusive_ptr cache_ids,
      int64_t time);

  torch::Tensor Evict(int64_t num_to_evict);
  torch::Tensor Save();

 private:
  // NOTE(review): presumably guards transformer_ and the bookkeeping below
  // against concurrent calls -- confirm.
  std::mutex mu_;
  details::IDTransformer transformer_;
  std::vector ids_to_fetch_;
  int64_t time_;
  int64_t last_save_time_;
};

} // namespace tde
36 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/benchmark_zch/models/configs/dlrmv2.yaml:
--------------------------------------------------------------------------------
1 | dense_arch_layer_sizes:
2 | - 512
3 | - 256
4 | - 64
5 | over_arch_layer_sizes:
6 | - 512
7 | - 512
8 | - 256
9 | - 1
10 | embedding_dim: 64
11 | num_embeddings_per_feature:
12 | cat_0: 40000000
13 | cat_1: 39060
14 | cat_2: 17295
15 | cat_3: 7424
16 | cat_4: 20265
17 | cat_5: 3
18 | cat_6: 7122
19 | cat_7: 1543
20 | cat_8: 63
21 | cat_9: 40000000
22 | cat_10: 3067956
23 | cat_11: 405282
24 | cat_12: 10
25 | cat_13: 2209
26 | cat_14: 11938
27 | cat_15: 155
28 | cat_16: 4
29 | cat_17: 976
30 | cat_18: 14
31 | cat_19: 40000000
32 | cat_20: 40000000
33 | cat_21: 40000000
34 | cat_22: 590152
35 | cat_23: 12973
36 | cat_24: 108
37 | cat_25: 36
38 | embedding_module_attribute_path: "dlrm.sparse_arch.embedding_bag_collection" # attribute path of the embedding module, relative to the model
39 | managed_collision_module_attribute_path: "module.dlrm.sparse_arch.embedding_bag_collection.mc_embedding_bag_collection._managed_collision_collection._managed_collision_modules" # attribute path of the managed collision modules, relative to the model
40 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/mixed_lfu_lru_strategy.cpp:
--------------------------------------------------------------------------------
1 | #include "mixed_lfu_lru_strategy.h"
2 | #include
3 | #include "c10/macros/Macros.h"
4 |
namespace tde::details {
// NOTE(review): template arguments in this file (std::atomic<...>,
// std::optional<...>, reinterpret_cast<...>) are missing in this copy, likely
// stripped during extraction -- restore before building.

// min_used_freq_power: frequency power given to ids that have no previous
// record. time_ is heap-allocated here; its owning type is declared in the
// header (not visible) -- confirm it is released.
MixedLFULRUStrategy::MixedLFULRUStrategy(uint16_t min_used_freq_power)
    : min_lfu_power_(min_used_freq_power), time_(new std::atomic()) {}

// Sets the timestamp written into records by subsequent Update() calls.
void MixedLFULRUStrategy::UpdateTime(uint32_t time) {
  time_->store(time);
}

// Builds the new lxu record for an id from its previous record `val`.
// The record stores the current time plus a "frequency power":
//  - no previous record: freq power = min_lfu_power_;
//  - otherwise the previous power is incremented when
//    generator_.IsNextNBitsAllZero(freq_power) is true (roughly probability
//    2^-freq_power if the generator yields uniform random bits -- confirm).
// global_id and cache_id are not used by this implementation.
// NOTE(review): ++freq_power has no upper bound check; if freq_power_ is a
// narrow bit-field in Record, it could wrap -- confirm the field width.
// NOTE(review): Record <-> lxu_record_t conversion relies on reinterpret_cast
// type punning; assumes sizeof(Record) == sizeof(lxu_record_t) -- confirm.
MixedLFULRUStrategy::lxu_record_t MixedLFULRUStrategy::Update(
    int64_t global_id,
    int64_t cache_id,
    std::optional val) {
  Record r{};
  r.time_ = time_->load();

  if (C10_UNLIKELY(!val.has_value())) {
    r.freq_power_ = min_lfu_power_;
  } else {
    auto freq_power = reinterpret_cast(&val.value())->freq_power_;
    bool should_carry = generator_.IsNextNBitsAllZero(freq_power);
    if (should_carry) {
      ++freq_power;
    }
    r.freq_power_ = freq_power;
  }
  return *reinterpret_cast(&r);
}

} // namespace tde::details
34 |
--------------------------------------------------------------------------------
/torchrec/distributed/tests/test_model_parallel_gloo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | from torchrec.distributed.test_utils.test_model_parallel import ModelParallelBase
11 | from torchrec.distributed.test_utils.test_model_parallel_base import (
12 | ModelParallelSparseOnlyBase,
13 | ModelParallelStateDictBase,
14 | )
15 |
16 | # CPU tests for Gloo.
17 |
18 |
class ModelParallelTestGloo(ModelParallelBase):
    """CPU model-parallel tests on the Gloo backend."""

    def setUp(self, backend: str = "gloo") -> None:
        # Only the backend differs from the base suite.
        super().setUp(backend=backend)
22 |
23 |
class ModelParallelStateDictTestGloo(ModelParallelStateDictBase):
    """CPU state-dict model-parallel tests on the Gloo backend."""

    def setUp(self, backend: str = "gloo") -> None:
        # Only the backend differs from the base suite.
        super().setUp(backend=backend)
27 |
28 |
class ModelParallelSparseOnlyTestGloo(ModelParallelSparseOnlyBase):
    """CPU sparse-only model-parallel tests on the Gloo backend."""

    def setUp(self, backend: str = "gloo") -> None:
        # Only the backend differs from the base suite.
        super().setUp(backend=backend)
32 |
--------------------------------------------------------------------------------
/torchrec/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import torchrec.distributed # noqa
11 | import torchrec.quant # noqa
12 | from torchrec.fx import tracer # noqa
13 | from torchrec.modules.embedding_configs import ( # noqa
14 | DataType,
15 | EmbeddingBagConfig,
16 | EmbeddingConfig,
17 | PoolingType,
18 | )
19 | from torchrec.modules.embedding_modules import ( # noqa
20 | EmbeddingBagCollection,
21 | EmbeddingBagCollectionInterface,
22 | EmbeddingCollection,
23 | ) # noqa
24 | from torchrec.sparse.jagged_tensor import ( # noqa
25 | JaggedTensor,
26 | KeyedJaggedTensor,
27 | KeyedTensor,
28 | )
29 | from torchrec.streamable import Multistreamable, Pipelineable # noqa
30 |
try:
    # Version metadata is optional: if the `version` module cannot be imported
    # (presumably it is generated at build/packaging time -- confirm),
    # `__version__` and `github_version` are simply left undefined.
    # pyre-ignore[21]
    # @manual=//torchrec/fb:version
    from .version import __version__, github_version  # noqa
except ImportError:
    pass
37 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | from typing import List
9 |
10 | import torch
11 |
# Load the compiled TDE extension that provides torch.classes.tde.*. Failure is
# tolerated so this module stays importable, but it is reported as a warning
# carrying the real cause. The previous print always said "not found", even
# for e.g. undefined-symbol / ABI mismatch errors.
try:
    torch.ops.load_library(os.path.join(os.path.dirname(__file__), "tde_cpp.so"))
except Exception as ex:
    import warnings

    warnings.warn(f"Failed to load tde_cpp.so: {ex}")
16 |
17 |
18 | __all__ = []
19 |
20 |
class TensorList:
    """Python wrapper around the ``torch.classes.tde.TensorList`` custom class.

    Args:
        tensors: tensors to store. Each tensor's ``.data`` is appended rather
            than the tensor itself.
    """

    def __init__(self, tensors: List[torch.Tensor]):
        self.tensor_list = torch.classes.tde.TensorList()
        for t in tensors:
            # `.data` allows in-place ops on the stored tensor during autograd.
            # https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/2
            self.tensor_list.append(t.data)

    def __len__(self):
        """Number of stored tensors."""
        return len(self.tensor_list)

    def __getitem__(self, i):
        """Return the i-th stored tensor."""
        return self.tensor_list[i]
34 |
--------------------------------------------------------------------------------
/torchrec/csrc/dynamic_embedding/details/redis/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
# hiredis is fetched at a pinned commit.
FetchContent_Declare(
  hiredis
  GIT_REPOSITORY https://github.com/redis/hiredis.git
  GIT_TAG 06be7ff312a78f69237e5963cc7d24bc84104d3b
)

FetchContent_GetProperties(hiredis)
if(NOT hiredis_POPULATED)
  # Do not include hiredis in install targets: populate it manually and add it
  # with EXCLUDE_FROM_ALL below.
  # NOTE(review): the single-argument FetchContent_Populate() form is
  # deprecated since CMake 3.30 (policy CMP0169).
  FetchContent_Populate(hiredis)
  # Skip building hiredis' own tests.
  set(DISABLE_TESTS ON CACHE BOOL "Disable tests for hiredis")
  add_subdirectory(
    ${hiredis_SOURCE_DIR} ${hiredis_BINARY_DIR} EXCLUDE_FROM_ALL)
endif()

# Shared library implementing the Redis IO backend, linking hiredis statically.
add_library(redis_io SHARED redis_io.cpp)
# Repository root: five levels above torchrec/csrc/dynamic_embedding/details/redis.
target_include_directories(
  redis_io PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../)
target_include_directories(redis_io PUBLIC ${TORCH_INCLUDE_DIRS})
target_compile_options(redis_io PUBLIC -fPIC)
target_link_libraries(redis_io PUBLIC hiredis::hiredis_static)
28 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/tools/before_linux_build.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
set -xe

distro=rhel7
arch=x86_64
CUDA_VERSION="${CUDA_VERSION:-11.8}"

# Split the version into its first two dot-separated components in one pass
# (e.g. "11.8" -> major "11", minor "8").
read -r CUDA_MAJOR_VERSION CUDA_MINOR_VERSION < <(
  echo "${CUDA_VERSION}" | tr '.' ' ' | awk '{print $1, $2}'
)

# CUDA toolkit + cuDNN from NVIDIA's yum repository.
yum install -y yum-utils
yum-config-manager --add-repo "https://developer.download.nvidia.com/compute/cuda/repos/${distro}/${arch}/cuda-${distro}.repo"
yum install -y \
  "cuda-toolkit-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}" \
  libcudnn8-devel
# Symlink target is relative, i.e. resolved inside /usr/local.
ln -s "cuda-${CUDA_MAJOR_VERSION}.${CUDA_MINOR_VERSION}" /usr/local/cuda

# Build tooling and a nightly torch wheel for the matching CUDA version.
pipx install cmake
pipx install ninja
python -m pip install scikit-build
python -m pip install --pre torch --extra-index-url \
  "https://download.pytorch.org/whl/nightly/cu${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION}"
29 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/README.md:
--------------------------------------------------------------------------------
1 | # TorchRec Benchmark
2 | ## benchmark_train_pipeline usage
3 | - internal:
4 | ```
5 | buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_train_pipeline -- \
6 | --yaml_config=fbcode/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml \
7 | --name=sparse_data_dist_base_$(hg whereami | cut -c 1-10 || echo $USER) # overrides the yaml config
8 | ```
9 | - oss:
10 | ```
11 | python -m torchrec.distributed.benchmark.benchmark_train_pipeline \
12 | --yaml_config=torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml \
13 | --name=sparse_data_dist_base_$(git rev-parse --short HEAD || echo $USER) # overrides the yaml config
14 | ```
15 |
16 | ## benchmark_comms usage
17 | - internal:
18 | ```
19 | buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms -- \
20 | a2a_single --name=a2a_sync_base-$(hg whereami | cut -c 1-10) --memory_snapshot=true
21 | ```
22 | - oss:
23 | ```
24 | python -m torchrec.distributed.benchmark.benchmark_comms \
25 | a2a_single --name=a2a_sync_base-$(git rev-parse --short HEAD || echo $USER) --memory_snapshot=true
26 | ```
27 |
--------------------------------------------------------------------------------
/torchrec/inference/include/torchrec/inference/TestUtils.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 |
11 | #include
12 | #include
13 |
14 | #include
15 |
16 | #include "torchrec/inference/JaggedTensor.h"
17 | #include "torchrec/inference/Types.h"
18 |
namespace torchrec {

// Test helper declarations; the definitions are not in this header.
// NOTE(review): template arguments (shared_ptr pointee, vector element types,
// c10::List element type) are missing from this copy, so these declarations do
// not compile as written; restore them from the source of truth.

// createRequest overloads all return a std::shared_ptr to a request object
// (presumably PredictionRequest from Types.h; TODO confirm).

// From a single dense tensor.
std::shared_ptr createRequest(at::Tensor denseTensor);

// From a batch size, a feature count and a JaggedTensor.
std::shared_ptr
createRequest(size_t batchSize, size_t numFeatures, const JaggedTensor& jagged);

// From a batch size, a feature count and an embedding tensor.
std::shared_ptr
createRequest(size_t batchSize, size_t numFeatures, at::Tensor embedding);

// Builds a JaggedTensor from nested vectors (assumed one inner vector per
// row; TODO confirm).
JaggedTensor createJaggedTensor(const std::vector>& input);

// Builds a c10::List from nested vectors.
c10::List createIValueList(
    const std::vector>& input);

// Builds an at::Tensor from nested vectors (presumably one row per inner
// vector; TODO confirm).
at::Tensor createEmbeddingTensor(
    const std::vector>& input);

} // namespace torchrec
38 |
--------------------------------------------------------------------------------
/torchrec/inference/inference_legacy/include/torchrec/inference/TestUtils.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 |
11 | #include
12 | #include
13 |
14 | #include
15 |
16 | #include "torchrec/inference/JaggedTensor.h"
17 | #include "torchrec/inference/Types.h"
18 |
namespace torchrec {

// Test helper declarations (legacy inference copy); the definitions are not in
// this header.
// NOTE(review): template arguments (shared_ptr pointee, vector element types,
// c10::List element type) are missing from this copy, so these declarations do
// not compile as written; restore them from the source of truth.

// createRequest overloads all return a std::shared_ptr to a request object
// (presumably PredictionRequest from Types.h; TODO confirm).

// From a single dense tensor.
std::shared_ptr createRequest(at::Tensor denseTensor);

// From a batch size, a feature count and a JaggedTensor.
std::shared_ptr
createRequest(size_t batchSize, size_t numFeatures, const JaggedTensor& jagged);

// From a batch size, a feature count and an embedding tensor.
std::shared_ptr
createRequest(size_t batchSize, size_t numFeatures, at::Tensor embedding);

// Builds a JaggedTensor from nested vectors (assumed one inner vector per
// row; TODO confirm).
JaggedTensor createJaggedTensor(const std::vector>& input);

// Builds a c10::List from nested vectors.
c10::List createIValueList(
    const std::vector>& input);

// Builds an at::Tensor from nested vectors (presumably one row per inner
// vector; TODO confirm).
at::Tensor createEmbeddingTensor(
    const std::vector>& input);

} // namespace torchrec
38 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/yaml/sparse_data_dist_ssd.yml:
--------------------------------------------------------------------------------
1 | # sparse data dist config that places large_table on the key_value (SSD) compute kernel
2 | # runs on 2 ranks, showing traces with reasonable workloads
3 | RunOptions:
4 | world_size: 2
5 | num_batches: 5
6 | num_benchmarks: 2
7 | sharding_type: table_wise
8 | profile_dir: "."
9 | name: "sparse_data_dist_base"
10 | # export_stacks: True # enable this to export stack traces
11 | PipelineConfig:
12 | pipeline: "sparse"
13 | EmbeddingTablesConfig:
14 | num_unweighted_features: 100
15 | num_weighted_features: 100
16 | embedding_feature_dim: 256
17 | additional_tables:
18 | - - name: FP16_table
19 | embedding_dim: 512
20 | num_embeddings: 100_000
21 | feature_names: ["additional_0_0"]
22 | data_type: FP16
23 | - name: large_table
24 | embedding_dim: 256
25 | num_embeddings: 1_000_000
26 | feature_names: ["additional_0_1"]
27 | - []
28 | - - name: skipped_table
29 | embedding_dim: 128
30 | num_embeddings: 100_000
31 | feature_names: ["additional_2_1"]
32 | PlannerConfig:
33 | additional_constraints:
34 | large_table:
35 | compute_kernels: [key_value]
36 | sharding_types: [row_wise]
37 |
--------------------------------------------------------------------------------
/torchrec/distributed/planner/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | """Torchrec Planner
11 |
12 | The planner provides the specifications necessary for a module to be sharded,
13 | considering the possible options to build an optimized plan.
14 |
15 | The features include:
16 | - generating all possible sharding options.
17 | - estimating perf and storage for every shard.
18 | - estimating peak memory usage to eliminate sharding plans that might OOM.
19 | - customizability for parameter constraints, partitioning, proposers, or performance
20 | modeling.
21 | - automatically building and selecting an optimized sharding plan.
22 | """
23 |
24 | from torchrec.distributed.planner.planners import ( # noqa # noqa
25 | EmbeddingPlannerBase,
26 | EmbeddingShardingPlanner,
27 | )
28 | from torchrec.distributed.planner.types import ParameterConstraints, Topology # noqa
29 | from torchrec.distributed.planner.utils import bytes_to_gb, sharder_name # noqa
30 |
--------------------------------------------------------------------------------
/torchrec/inference/include/torchrec/inference/ShardedTensor.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 |
11 | #include
12 |
13 | #include
14 |
namespace torchrec {

// Plain structs describing sharded and replicated tensors.
// NOTE(review): the element types of the std::vector members are missing from
// this copy (presumably int64_t / Shard / ShardMetadata); confirm.

// Where one shard sits: its offsets and lengths (assumed one entry per tensor
// dimension; TODO confirm).
struct ShardMetadata {
  std::vector shard_offsets;
  std::vector shard_lengths;

  // Equal when both the offsets and the lengths compare equal.
  bool operator==(const ShardMetadata& other) const {
    return shard_offsets == other.shard_offsets &&
        shard_lengths == other.shard_lengths;
  }
};

// A shard: its metadata plus the tensor holding its data.
struct Shard {
  ShardMetadata metadata;
  at::Tensor tensor;
};

// Metadata for the shards of a sharded tensor.
struct ShardedTensorMetadata {
  std::vector shards_metadata;
};

// A sharded tensor: its sizes, the shards held locally, and the shard
// metadata.
struct ShardedTensor {
  std::vector sizes;
  std::vector local_shards;
  ShardedTensorMetadata metadata;
};

// One local replica (itself a ShardedTensor) with its replica id and the
// total number of replicas.
struct ReplicatedTensor {
  ShardedTensor local_replica;
  int64_t local_replica_id;
  int64_t replica_count;
};

} // namespace torchrec
49 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/yaml/sparse_data_dist_emo.yml:
--------------------------------------------------------------------------------
1 | # sparse data dist config that places large_table on the fused_uvm_caching compute kernel
2 | # runs on 2 ranks, showing traces with reasonable workloads
3 | RunOptions:
4 | world_size: 2
5 | num_batches: 5
6 | num_benchmarks: 2
7 | sharding_type: table_wise
8 | profile_dir: "."
9 | name: "sparse_data_dist_base"
10 | # export_stacks: True # enable this to export stack traces
11 | PipelineConfig:
12 | pipeline: "sparse"
13 | EmbeddingTablesConfig:
14 | num_unweighted_features: 100
15 | num_weighted_features: 100
16 | embedding_feature_dim: 256
17 | additional_tables:
18 | - - name: FP16_table
19 | embedding_dim: 512
20 | num_embeddings: 100_000
21 | feature_names: ["additional_0_0"]
22 | data_type: FP16
23 | - name: large_table
24 | embedding_dim: 256
25 | num_embeddings: 1_000_000
26 | feature_names: ["additional_0_1"]
27 | - []
28 | - - name: skipped_table
29 | embedding_dim: 128
30 | num_embeddings: 100_000
31 | feature_names: ["additional_2_1"]
32 | PlannerConfig:
33 | additional_constraints:
34 | large_table:
35 | compute_kernels: [fused_uvm_caching]
36 | sharding_types: [row_wise]
37 |
--------------------------------------------------------------------------------
/torchrec/inference/inference_legacy/include/torchrec/inference/ShardedTensor.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 |
11 | #include
12 |
13 | #include
14 |
namespace torchrec {

// Plain structs describing sharded and replicated tensors (legacy inference
// copy).
// NOTE(review): the element types of the std::vector members are missing from
// this copy (presumably int64_t / Shard / ShardMetadata); confirm.

// Where one shard sits: its offsets and lengths (assumed one entry per tensor
// dimension; TODO confirm).
struct ShardMetadata {
  std::vector shard_offsets;
  std::vector shard_lengths;

  // Equal when both the offsets and the lengths compare equal.
  bool operator==(const ShardMetadata& other) const {
    return shard_offsets == other.shard_offsets &&
        shard_lengths == other.shard_lengths;
  }
};

// A shard: its metadata plus the tensor holding its data.
struct Shard {
  ShardMetadata metadata;
  at::Tensor tensor;
};

// Metadata for the shards of a sharded tensor.
struct ShardedTensorMetadata {
  std::vector shards_metadata;
};

// A sharded tensor: its sizes, the shards held locally, and the shard
// metadata.
struct ShardedTensor {
  std::vector sizes;
  std::vector local_shards;
  ShardedTensorMetadata metadata;
};

// One local replica (itself a ShardedTensor) with its replica id and the
// total number of replicas.
struct ReplicatedTensor {
  ShardedTensor local_replica;
  int64_t local_replica_id;
  int64_t replica_count;
};

} // namespace torchrec
49 |
--------------------------------------------------------------------------------
/torchrec/inference/include/torchrec/inference/ExceptionHandler.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 |
11 | #include
12 | #include
13 |
14 | #include
15 |
16 | #include "torchrec/inference/Exception.h"
17 | #include "torchrec/inference/Types.h"
18 |
namespace torchrec {
// NOTE(review): the template parameter lists and the template arguments of
// Promise / make_exception_wrapper / make_unique / vector are missing from
// this copy; restore them from the source of truth.

// Resolves `promise` with a freshly allocated response whose `exception`
// field holds an exception wrapper built from `msg`. The promise itself is
// fulfilled via setValue (not setException): the error travels inside the
// response object.
template
void handleRequestException(
    folly::Promise>& promise,
    const std::string& msg) {
  auto ex = folly::make_exception_wrapper(msg);
  auto response = std::make_unique();
  response->exception = std::move(ex);
  promise.setValue(std::move(response));
}

// Applies handleRequestException with the same `msg` to the `promise` member
// of every element of `contexts`.
template
void handleBatchException(
    std::vector& contexts,
    const std::string& msg) {
  for (auto& context : contexts) {
    handleRequestException(context.promise, msg);
  }
}

} // namespace torchrec
40 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml:
--------------------------------------------------------------------------------
1 | # this is a very basic sparse data dist config
2 | # runs on 2 ranks, showing traces with reasonable workloads
3 | RunOptions:
4 | world_size: 2
5 | num_batches: 10
6 | num_benchmarks: 1
7 | num_profiles: 1
8 | sharding_type: table_wise
9 | profile_dir: "."
10 | name: "sparse_data_dist_base"
11 | # export_stacks: True # enable this to export stack traces
12 | PipelineConfig:
13 | pipeline: "sparse"
14 | ModelInputConfig:
15 | feature_pooling_avg: 30
16 | EmbeddingTablesConfig:
17 | num_unweighted_features: 90
18 | num_weighted_features: 80
19 | embedding_feature_dim: 256
20 | additional_tables:
21 | - - name: FP16_table
22 | embedding_dim: 512
23 | num_embeddings: 100_000
24 | feature_names: ["additional_0_0"]
25 | data_type: FP16
26 | - name: large_table
27 | embedding_dim: 2048
28 | num_embeddings: 1_000_000
29 | feature_names: ["additional_0_1"]
30 | - []
31 | - - name: skipped_table
32 | embedding_dim: 128
33 | num_embeddings: 100_000
34 | feature_names: ["additional_2_1"]
35 | PlannerConfig:
36 | additional_constraints:
37 | large_table:
38 | sharding_types: [column_wise]
39 |
--------------------------------------------------------------------------------
/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 |
11 | #include
12 | #include
13 |
14 | #include
15 |
16 | #include "torchrec/inference/Exception.h"
17 | #include "torchrec/inference/Types.h"
18 |
namespace torchrec {
// NOTE(review): the template parameter lists and the template arguments of
// Promise / make_exception_wrapper / make_unique / vector are missing from
// this copy; restore them from the source of truth.

// Resolves `promise` with a freshly allocated response whose `exception`
// field holds an exception wrapper built from `msg`. The promise itself is
// fulfilled via setValue (not setException): the error travels inside the
// response object.
template
void handleRequestException(
    folly::Promise>& promise,
    const std::string& msg) {
  auto ex = folly::make_exception_wrapper(msg);
  auto response = std::make_unique();
  response->exception = std::move(ex);
  promise.setValue(std::move(response));
}

// Applies handleRequestException with the same `msg` to the `promise` member
// of every element of `contexts`.
template
void handleBatchException(
    std::vector& contexts,
    const std::string& msg) {
  for (auto& context : contexts) {
    handleRequestException(context.promise, msg);
  }
}

} // namespace torchrec
40 |
--------------------------------------------------------------------------------
/.github/workflows/validate-binaries.yml:
--------------------------------------------------------------------------------
1 | name: Validate binaries
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | channel:
7 | description: "Channel to use (nightly, release)"
8 | required: false
9 | type: string
10 | default: release
11 | ref:
12 | description: 'Reference to checkout, defaults to empty'
13 | default: ""
14 | required: false
15 | type: string
16 | workflow_dispatch:
17 | inputs:
18 | channel:
19 | description: "Channel to use (nightly, release, test)"
20 | required: true
21 | type: choice
22 | options:
23 | - release
24 | - nightly
25 | - test
26 | ref:
27 | description: 'Reference to checkout, defaults to empty'
28 | default: ""
29 | required: false
30 | type: string
31 |
32 | jobs:
33 | validate-binaries:
34 | uses: pytorch/test-infra/.github/workflows/validate-domain-library.yml@main
35 | with:
36 | package_type: "wheel"
37 | os: "linux"
38 | channel: ${{ inputs.channel }}
39 | repository: "meta-pytorch/torchrec"
40 | smoke_test: "source ./.github/scripts/validate_binaries.sh"
41 | with_cuda: enable
42 | with_rocm: false
43 |
--------------------------------------------------------------------------------
/torchrec/csrc/dynamic_embedding/details/lxu_strategy.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 | #include
11 | #include
12 |
13 | namespace torchrec {
14 |
15 | class LXUStrategy {
16 | public:
17 | LXUStrategy() = default;
18 | LXUStrategy(const LXUStrategy&) = delete;
19 | LXUStrategy(LXUStrategy&& o) noexcept = default;
20 |
21 | virtual void update_time(uint32_t time) = 0;
22 | virtual int64_t time(lxu_record_t record) = 0;
23 |
24 | virtual lxu_record_t update(
25 | int64_t global_id,
26 | int64_t cache_id,
27 | std::optional val) = 0;
28 |
29 | /**
30 | * Analysis all ids and returns the num_elems that are most need to evict.
31 | * @param iterator Returns each global_id to ExtValue pair. Returns nullopt
32 | * when at ends.
33 | * @param num_to_evict
34 | * @return
35 | */
36 | virtual std::vector evict(
37 | iterator_t iterator,
38 | uint64_t num_to_evict) = 0;
39 | };
40 |
41 | } // namespace torchrec
42 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/yaml/base_pipeline_light.yml:
--------------------------------------------------------------------------------
1 | # a lightweight config for the base train pipeline
2 | # runs on 2 ranks, showing traces with reasonable workloads
3 | RunOptions:
4 | world_size: 2
5 | num_batches: 10
6 | num_benchmarks: 1
7 | num_profiles: 1
8 | sharding_type: table_wise
9 | profile_dir: "."
10 | name: "base_pipeline_light"
11 | # export_stacks: True # enable this to export stack traces
12 | loglevel: "info"
13 | PipelineConfig:
14 | pipeline: "base"
15 | ModelInputConfig:
16 | feature_pooling_avg: 10
17 | EmbeddingTablesConfig:
18 | num_unweighted_features: 20
19 | num_weighted_features: 20
20 | embedding_feature_dim: 256
21 | additional_tables:
22 | - - name: FP16_table
23 | embedding_dim: 512
24 | num_embeddings: 100_000
25 | feature_names: ["additional_0_0"]
26 | data_type: FP16
27 | - name: large_table
28 | embedding_dim: 2048
29 | num_embeddings: 1_000_000
30 | feature_names: ["additional_0_1"]
31 | - []
32 | - - name: skipped_table
33 | embedding_dim: 128
34 | num_embeddings: 100_000
35 | feature_names: ["additional_2_1"]
36 | PlannerConfig:
37 | additional_constraints:
38 | large_table:
39 | sharding_types: [column_wise]
40 |
--------------------------------------------------------------------------------
/torchrec/csrc/dynamic_embedding/details/bits_op.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 | #include
11 | #include
12 |
namespace torchrec {

// NOTE(review): the template parameter lists and the template arguments of
// Clz/Ctz are missing from this copy; restore them from the source of truth.

namespace bits_impl {
// Function objects behind clz()/ctz(). operator() is only declared here; the
// definitions (presumably per-type specializations) are not in this header.
template
struct Clz {
  int operator()(T v) const;
};

template
struct Ctz {
  int operator()(T v) const;
};
} // namespace bits_impl

/**
 * Returns the number of leading 0-bits in t, starting at the most significant
 * bit position. If t is 0, the result is undefined.
 * clz stands for counting leading zeros.
 * Delegates to bits_impl::Clz.
 */
template
inline int clz(T t) {
  bits_impl::Clz clz;
  return clz(t);
}

/**
 * Returns the number of trailing 0-bits in t, starting at the least significant
 * bit position. If t is 0, the result is undefined.
 * ctz stands for counting trailing zeros.
 * Delegates to bits_impl::Ctz.
 */
template
inline int ctz(T t) {
  bits_impl::Ctz ctz;
  return ctz(t);
}

} // namespace torchrec
50 |
--------------------------------------------------------------------------------
/.github/scripts/install_libs.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
echo "CU_VERSION: ${CU_VERSION}"
echo "CHANNEL: ${CHANNEL}"
echo "CONDA_ENV: ${CONDA_ENV}"

# Library search path baked into the conda env. For CUDA builds, adding
# ${CUDA_HOME}/lib64 fixes the runtime error with fbgemm_gpu not being able to
# locate libnvrtc.so.
lib_path="/usr/local/lib"
if [[ $CU_VERSION = cu* ]]; then
  lib_path="${lib_path}:${CUDA_HOME}/lib64"
fi
echo "[NOVA] Setting LD_LIBRARY_PATH ..."
# Quote the env path so a prefix containing spaces/globs stays one argument.
conda env config vars set -p "${CONDA_ENV}" \
  LD_LIBRARY_PATH="${lib_path}:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}"

# nightly/test channels install fbgemm-gpu from the matching PyTorch index;
# other channels install nothing extra at this step.
# ${CONDA_RUN} is intentionally unquoted: it expands to a command plus args.
case "$CHANNEL" in
  nightly|test)
    ${CONDA_RUN} pip install fbgemm-gpu --index-url "https://download.pytorch.org/whl/${CHANNEL}/${CU_VERSION}"
    ;;
esac

${CONDA_RUN} pip install -r requirements.txt
32 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_benchmark.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include "benchmark/benchmark.h"
3 | #include "tde/details/cacheline_id_transformer.h"
4 |
namespace tde::details {

// Benchmarks CachelineIDTransformer::Transform over a 1024x1024 tensor of
// torch::kLong ids. The ids are refilled between PauseTiming/ResumeTiming, so
// only the Transform call is timed.
// NOTE(review): template arguments (data_ptr element type, static_cast target,
// span element types, possibly the transformer's own parameters) are missing
// from this copy; restore them from the source of truth.
static void BM_CachelineIDTransformer(benchmark::State& state) {
  // 2e8 is passed to the constructor (presumably a capacity; TODO confirm).
  CachelineIDTransformer transformer(2e8);
  torch::Tensor global_ids = torch::empty({1024, 1024}, torch::kLong);
  torch::Tensor cache_ids = torch::empty_like(global_ids);
  for (auto _ : state) {
    state.PauseTiming();
    // Fresh ids drawn from [rand_from, rand_to), the two benchmark args.
    global_ids.random_(state.range(0), state.range(1));
    state.ResumeTiming();
    // Spans over the raw storage of both tensors; cache_ids presumably
    // receives the transformed ids.
    transformer.Transform(
        tcb::span{
            global_ids.template data_ptr(),
            static_cast(global_ids.numel())},
        tcb::span{
            cache_ids.template data_ptr(),
            static_cast(cache_ids.numel())});
  }
}

// 100 fixed iterations per configuration, reported in milliseconds; ids are
// drawn from [1e10, 2e10) and from [1e6, 2e6).
BENCHMARK(BM_CachelineIDTransformer)
    ->Iterations(100)
    ->Unit(benchmark::kMillisecond)
    ->ArgNames({"rand_from", "rand_to"})
    ->Args({static_cast(1e10), static_cast(2e10)})
    ->Args({static_cast(1e6), static_cast(2e6)});

} // namespace tde::details
33 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/naive_id_transformer_benchmark.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include "benchmark/benchmark.h"
3 | #include "tde/details/naive_id_transformer.h"
4 |
namespace tde::details {

// Benchmarks NaiveIDTransformer::Transform over a 1024x1024 tensor of
// torch::kLong ids. The ids are refilled between PauseTiming/ResumeTiming, so
// only the Transform call is timed.
// NOTE(review): template arguments (data_ptr element type, static_cast target,
// span element types, possibly the transformer's own parameters) are missing
// from this copy; restore them from the source of truth.
static void BM_NaiveIDTransformer(benchmark::State& state) {
  // NOTE(review): Tag is not referenced in this copy; presumably it was a
  // template argument of NaiveIDTransformer that got stripped. Confirm.
  using Tag = int32_t;
  // 2e8 is passed to the constructor (presumably a capacity; TODO confirm).
  NaiveIDTransformer transformer(2e8);
  torch::Tensor global_ids = torch::empty({1024, 1024}, torch::kLong);
  torch::Tensor cache_ids = torch::empty_like(global_ids);
  for (auto _ : state) {
    state.PauseTiming();
    // Fresh ids drawn from [rand_from, rand_to), the two benchmark args.
    global_ids.random_(state.range(0), state.range(1));
    state.ResumeTiming();
    // Spans over the raw storage of both tensors; cache_ids presumably
    // receives the transformed ids.
    transformer.Transform(
        tcb::span{
            global_ids.template data_ptr(),
            static_cast(global_ids.numel())},
        tcb::span{
            cache_ids.template data_ptr(),
            static_cast(cache_ids.numel())});
  }
}

// 100 fixed iterations per configuration, reported in milliseconds; ids are
// drawn from [1e10, 2e10) and from [1e6, 2e6).
BENCHMARK(BM_NaiveIDTransformer)
    ->Iterations(100)
    ->Unit(benchmark::kMillisecond)
    ->ArgNames({"rand_from", "rand_to"})
    ->Args({static_cast(1e10), static_cast(2e10)})
    ->Args({static_cast(1e6), static_cast(2e6)});

} // namespace tde::details
34 |
--------------------------------------------------------------------------------
/torchrec/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | """Torchrec Datasets
11 |
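12 | Torchrec contains two popular recsys datasets, the `Kaggle/Criteo Display Advertising <https://www.kaggle.com/c/criteo-display-ad-challenge/>`_ Dataset
13 | and the `MovieLens 20M <https://grouplens.org/datasets/movielens/20m/>`_ Dataset.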
14 |
15 | Additionally, it contains a RandomDataset, which is useful to generate random data in the same format as the above.
16 |
17 | Lastly, it contains scripts and utilities for pre-processing, loading, etc.
18 |
19 | Example::
20 |
21 | from torchrec.datasets.criteo import criteo_terabyte
22 | datapipe = criteo_terabyte(
23 | ("/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv")
24 | )
25 | datapipe = dp.iter.Batcher(datapipe, 100)
26 | datapipe = dp.iter.Collator(datapipe)
27 | batch = next(iter(datapipe))
28 | """
29 |
30 | import torchrec.datasets.criteo # noqa
31 | import torchrec.datasets.movielens # noqa
32 | import torchrec.datasets.random # noqa
33 | import torchrec.datasets.utils # noqa
34 |
--------------------------------------------------------------------------------
/torchrec/quant/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | """Torchrec Quantization
11 |
12 | Torchrec provides a quantized version of EmbeddingBagCollection for inference.
13 | It relies on fbgemm quantized ops.
14 | This reduces the size of the model weights and speeds up model execution.
15 |
16 | Example:
17 | >>> import torch.quantization as quant
18 | >>> import torchrec.quant as trec_quant
19 | >>> import torchrec as trec
20 | >>> qconfig = quant.QConfig(
21 | >>> activation=quant.PlaceholderObserver,
22 | >>> weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
23 | >>> )
24 | >>> quantized = quant.quantize_dynamic(
25 | >>> module,
26 | >>> qconfig_spec={
27 | >>> trec.EmbeddingBagCollection: qconfig,
28 | >>> },
29 | >>> mapping={
30 | >>> trec.EmbeddingBagCollection: trec_quant.EmbeddingBagCollection,
31 | >>> },
32 | >>> inplace=inplace,
33 | >>> )
34 | """
35 |
36 | from torchrec.quant.embedding_modules import EmbeddingBagCollection # noqa
37 |
--------------------------------------------------------------------------------
/torchrec/distributed/benchmark/yaml/sparse_data_dist_base_vbe.yml:
--------------------------------------------------------------------------------
1 | # sparse data dist config with variable batch sizes (use_variable_batch) enabled
2 | # runs on 2 ranks, showing traces with reasonable workloads
3 | RunOptions:
4 | world_size: 2
5 | batch_size: 16384
6 | num_batches: 10
7 | num_benchmarks: 1
8 | num_profiles: 1
9 | sharding_type: table_wise
10 | profile_dir: "."
11 | name: "sparse_data_dist_base"
12 | # export_stacks: True # enable this to export stack traces
13 | PipelineConfig:
14 | pipeline: "sparse"
15 | ModelInputConfig:
16 | feature_pooling_avg: 30
17 | use_variable_batch: True
18 | EmbeddingTablesConfig:
19 | num_unweighted_features: 90
20 | num_weighted_features: 80
21 | embedding_feature_dim: 256
22 | additional_tables:
23 | - - name: FP16_table
24 | embedding_dim: 512
25 | num_embeddings: 100_000
26 | feature_names: ["additional_0_0"]
27 | data_type: FP16
28 | - name: large_table
29 | embedding_dim: 2048
30 | num_embeddings: 1_000_000
31 | feature_names: ["additional_0_1"]
32 | - []
33 | - - name: skipped_table
34 | embedding_dim: 128
35 | num_embeddings: 100_000
36 | feature_names: ["additional_2_1"]
37 | PlannerConfig:
38 | additional_constraints:
39 | large_table:
40 | sharding_types: [column_wise]
41 |
--------------------------------------------------------------------------------
/torchrec/modules/tests/test_code_quality.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import inspect
11 | import sys
12 | import unittest
13 |
14 | import torch
15 | import torchrec # noqa
16 | from torchrec.linter.module_linter import MAX_NUM_ARGS_IN_MODULE_CTOR
17 |
18 |
class CodeQualityTest(unittest.TestCase):
    def test_num_ctor_args(self) -> None:
        """Cap the constructor size of public TorchRec modules.

        Every ``torch.nn.Module`` subclass exposed as an attribute of the
        top-level ``torchrec`` package may declare at most
        ``MAX_NUM_ARGS_IN_MODULE_CTOR`` positional ``__init__`` arguments,
        not counting ``self``.
        """
        exported = inspect.getmembers(sys.modules["torchrec"], inspect.isclass)
        module_classes = (
            (name, cls) for name, cls in exported if issubclass(cls, torch.nn.Module)
        )
        for name, cls in module_classes:
            # getfullargspec(...).args lists positional parameters, including
            # ``self`` -- subtract it from the count.
            arg_count = len(inspect.getfullargspec(cls.__init__).args) - 1
            self.assertLessEqual(
                arg_count,
                MAX_NUM_ARGS_IN_MODULE_CTOR,
                f"Modules in TorchRec can have no more than {MAX_NUM_ARGS_IN_MODULE_CTOR} "
                f"constructor args, but {name} has {arg_count}.",
            )
34 |
--------------------------------------------------------------------------------
/torchrec/sparse/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | """Torchrec Jagged Tensors
11 |
12 | It has 3 classes: JaggedTensor, KeyedJaggedTensor, KeyedTensor.
13 |
14 | JaggedTensor
15 |
16 | It represents an (optionally weighted) jagged tensor. A JaggedTensor is a
17 | tensor with a jagged dimension, which is a dimension whose slices may be of
18 | different lengths. See KeyedJaggedTensor docstring for full example and further
19 | information.
20 |
21 | KeyedJaggedTensor
22 |
23 | KeyedJaggedTensor has additional "Key" information. Keyed on first dimension,
24 | and jagged on last dimension. Please refer to KeyedJaggedTensor docstring for full example and
25 | further information.
26 |
27 | KeyedTensor
28 |
29 | KeyedTensor holds a concatenated list of dense tensors each of which can be accessed by a key.
30 | Keyed dimension can be variable length (length_per_key). Common use cases include storage
31 | of pooled embeddings of different dimensions. Please refer to KeyedTensor docstring for full
32 | example and further information.
33 | """
34 |
35 | from . import jagged_tensor # noqa
36 |
--------------------------------------------------------------------------------
/torchrec/csrc/dynamic_embedding/details/io_parameter.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include
10 |
namespace torchrec {

// Parameter blocks handed to IO providers (see IOProvider::fetch / ::push in
// io_registry.h, which take these structs by value). Because providers can be
// loaded as plugins, the field order and types here are presumably part of a
// binary interface: do not reorder or retype fields without updating every
// provider.
//
// NOTE(review): unlike the sibling headers io_registry.h and bitmap.h, this
// header has no `#pragma once`; including it twice in one translation unit
// would redefine these structs. Consider adding one.

// Callback invoked by a provider for fetched data. Judging by the parameter
// names it receives the caller's context, a row offset, an optimizer-state
// index, and a (pointer, byte length) pair for the data — TODO confirm against
// a provider implementation.
using GlobalIDFetchCallback = void (*)(
    void* ctx,
    uint32_t offset,
    uint32_t optimizer_state,
    void* data,
    uint32_t data_len);

// Request to fetch `num_global_ids` ids (`global_ids`) of table `table_name`.
struct IOFetchParameter {
  const char* table_name;
  // Number of entries in `col_ids`.
  uint32_t num_cols;
  // Number of entries in `global_ids`.
  uint32_t num_global_ids;
  const int64_t* col_ids;
  const int64_t* global_ids;
  uint32_t num_optimizer_states;
  // Opaque pointer presumably passed back as `ctx` to the callbacks below.
  void* on_complete_context;
  // Invoked for fetched data (see GlobalIDFetchCallback).
  GlobalIDFetchCallback on_global_id_fetched;
  // Invoked once the whole fetch request has finished.
  void (*on_all_fetched)(void* ctx);
};

// Request to push `num_global_ids` ids (`global_ids`) of table `table_name`,
// with their payload in `data`.
struct IOPushParameter {
  const char* table_name;
  // Number of entries in `col_ids`.
  uint32_t num_cols;
  // Number of entries in `global_ids`.
  uint32_t num_global_ids;
  const int64_t* col_ids;
  const int64_t* global_ids;
  // Number of entries in `optimizer_state_ids`.
  uint32_t num_optimizer_states;
  const uint32_t* optimizer_state_ids;
  // Number of entries in `offsets`; presumably byte offsets into `data` —
  // TODO confirm.
  uint32_t num_offsets;
  const uint64_t* offsets;
  const void* data;
  // Opaque pointer presumably passed back as `ctx` to on_push_complete.
  void* on_complete_context;
  // Invoked once the push request has finished.
  void (*on_push_complete)(void* ctx);
};

} // namespace torchrec
48 |
--------------------------------------------------------------------------------
/torchrec/modules/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | """Torchrec Common Modules
11 |
12 | The torchrec modules contain a collection of various modules.
13 |
14 | These modules include:
15 | - extensions of `nn.Embedding` and `nn.EmbeddingBag`, called `EmbeddingCollection`
16 | and `EmbeddingBagCollection` respectively.
17 | - established modules such as `DeepFM <https://arxiv.org/pdf/1703.04247.pdf>`_ and
18 | `CrossNet <https://arxiv.org/abs/1708.05123>`_.
19 | - common module patterns such as `MLP` and `SwishLayerNorm`.
20 | - custom modules for TorchRec such as `PositionWeightedModule` and
21 | `LazyModuleExtensionMixin`.
22 | - `EmbeddingTower` and `EmbeddingTowerCollection`, logical "tower" of embeddings
23 | passed to provided interaction module.
24 | """
25 |
26 | from . import ( # noqa # noqa # noqa # noqa # noqa # noqa # noqa # noqa # noqa
27 | activation,
28 | crossnet,
29 | deepfm,
30 | embedding_configs,
31 | embedding_modules,
32 | embedding_tower,
33 | feature_processor,
34 | lazy_extension,
35 | mlp,
36 | )
37 |
--------------------------------------------------------------------------------
/examples/bert4rec/models/tests/test_bert4rec.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | #!/usr/bin/env python3
11 |
12 | import unittest
13 |
14 | import torch
15 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
16 |
17 | from ..bert4rec import BERT4Rec
18 |
19 |
class BERT4RecTest(unittest.TestCase):
    """Shape check for a forward pass of ``BERT4Rec``."""

    def test_bert4rec(self) -> None:
        # Input: one "item" feature with two jagged sequences
        #   [2, 4],
        #   [3, 4, 5],
        input_kjt = KeyedJaggedTensor.from_lengths_sync(
            keys=["item"],
            values=torch.tensor([2, 4, 3, 4, 5]),
            lengths=torch.tensor([2, 3]),
        )
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        input_kjt = input_kjt.to(device)
        bert4rec = BERT4Rec(
            vocab_size=6, max_len=3, emb_dim=4, nhead=4, num_layers=4, device=device
        )
        logits = bert4rec(input_kjt)
        # Use the TestCase assertion rather than a bare `assert`: bare asserts
        # are stripped under `python -O`, and assertEqual reports both sizes
        # on failure.
        self.assertEqual(
            logits.size(),
            torch.Size([input_kjt.stride(), bert4rec.max_len, bert4rec.vocab_size]),
        )
41 |
--------------------------------------------------------------------------------
/torchrec/inference/inference_legacy/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-ignore-all-errors[0, 21]
9 |
10 | """Torchrec Inference
11 |
12 | Torchrec inference provides a Torch.Deploy based library for GPU inference.
13 |
14 | These includes:
15 | - Model packaging in Python
16 | - `PredictModule` and `PredictFactory` are the contracts between the Python model authoring and the C++ model serving.
17 | - `PredictFactoryPackager` can be used to package a PredictFactory class using torch.package.
18 | - Model serving in C++
19 | - `BatchingQueue` is a generalized config-based request tensor batching implementation.
20 | - `GPUExecutor` handles the forward call into the inference model inside Torch.Deploy.
21 |
22 | We implemented an example of how to use this library with the TorchRec DLRM model.
23 | - `examples/dlrm/inference/dlrm_packager.py`: this demonstrates how to export the DLRM model as a torch.package.
24 | - `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model.
25 | """
26 |
27 | from . import model_packager # noqa
28 |
--------------------------------------------------------------------------------
/torchrec/inference/protos/predictor.proto:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
syntax = "proto3";

package predictor;

// Sparse (jagged) features for a batch, packed into raw byte buffers.
// NOTE(review): T and B presumably denote the number of features and the
// batch size; the element dtypes of `values` and `weights` are not stated
// here — confirm against the client/server serialization code.
message SparseFeatures {
  int32 num_features = 1;
  // int32: T x B
  bytes lengths = 2;
  // T x B x L (jagged)
  bytes values = 3;
  // Presumably per-value weights (used for id-score-list features) — TODO
  // confirm.
  bytes weights = 4;
}

// Dense features packed into a raw byte buffer.
message FloatFeatures {
  int32 num_features = 1;
  // shape: {B}
  bytes values = 2;
}

// A single batched inference request.
message PredictionRequest {
  int32 batch_size = 1;
  FloatFeatures float_features = 2;
  SparseFeatures id_list_features = 3;
  SparseFeatures id_score_list_features = 4;
  FloatFeatures embedding_features = 5;
  SparseFeatures unary_features = 6;
}

// A flat list of floats; presumably the per-task value type of
// PredictionResponse.predictions (see the TODO below).
message FloatVec {
  repeated float data = 1;
}

// TODO: See whether FloatVec can be replaced with folly::iobuf
message PredictionResponse {
  // Task name to prediction Tensor
  map predictions = 1;
}

// The predictor service definition. Synchronous for now.
service Predictor {
  rpc Predict(PredictionRequest) returns (PredictionResponse) {}
}
51 |
--------------------------------------------------------------------------------
/benchmarks/cpp/dynamic_embedding/random_bits_generator_benchmark.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include
10 | #include
11 | #include
12 |
namespace torchrec {

// Benchmarks BitScanner::is_next_n_bits_all_zero on a scanner constructed
// with `n`. Before each query, the span passed to BitScanner::reset_array
// (presumably the scanner's backing words) is refilled with values from a
// 64-bit Mersenne Twister.
//   range(0): n, the BitScanner size.
//   range(1): upper bound for the number of bits queried per call.
void BMRandomBitsGenerator(benchmark::State& state) {
  auto n = state.range(0);
  auto n_bits_limit = state.range(1);
  BitScanner scanner(n);
  std::mt19937_64 engine((std::random_device())());
  // Query widths are drawn from [1, n_bits_limit].
  std::uniform_int_distribution dist(1, n_bits_limit);
  uint16_t n_bits = dist(engine);
  for (auto _ : state) {
    // NOTE(review): n_bits starts in [1, n_bits_limit] and is only reassigned
    // in the else branch, so `n_bits != 0` is always true. The else branch
    // is unreachable and the same n_bits is queried for the whole run.
    // Presumably a different refill/redraw condition was intended — confirm.
    if (n_bits != 0) {
      scanner.reset_array([&](std::span span) {
        for (auto& v : span) {
          v = engine();
        }
      });
    } else {
      n_bits = dist(engine);
    }
    // Consume the result so the query is not optimized away.
    benchmark::DoNotOptimize(scanner.is_next_n_bits_all_zero(n_bits));
  }
}

// n = 1024, limit = 32; reported in milliseconds over a fixed 1M iterations.
BENCHMARK(BMRandomBitsGenerator)
    ->ArgNames({"n", "limit"})
    ->Args({1024, 32})
    ->Unit(benchmark::kMillisecond)
    ->Iterations(1024 * 1024);

} // namespace torchrec
43 |
--------------------------------------------------------------------------------
/torchrec/datasets/scripts/nvt/utils/dask.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import os
9 | import shutil
10 |
11 | import numba
12 | from dask.distributed import Client
13 | from dask_cuda import LocalCUDACluster
14 | from nvtabular.utils import device_mem_size
15 |
16 |
def setup_dask(dask_workdir, *, device_limit_frac=0.8, device_pool_frac=0.7):
    """Start a local Dask cluster with one CUDA worker per GPU and connect to it.

    ``dask_workdir`` is deleted if it already exists and then recreated, so the
    workers always start from an empty local directory.

    Args:
        dask_workdir: Directory passed to the cluster as ``local_directory``.
        device_limit_frac: Fraction of the total device memory (as reported by
            ``device_mem_size(kind="total")``) at which a worker spills GPU
            memory to host. Defaults to 0.8.
        device_pool_frac: Fraction of the total device memory used to size the
            RMM memory pool. Defaults to 0.7.

    Returns:
        A ``dask.distributed.Client`` connected to the new ``LocalCUDACluster``.

    Raises:
        ValueError: If either fraction is outside the interval ``(0, 1]``.
    """
    # Validate before touching the filesystem, so bad arguments never cause
    # the work directory to be deleted.
    for name, frac in (
        ("device_limit_frac", device_limit_frac),
        ("device_pool_frac", device_pool_frac),
    ):
        if not 0 < frac <= 1:
            raise ValueError(f"{name} must be in (0, 1], got {frac!r}")

    if os.path.exists(dask_workdir):
        shutil.rmtree(dask_workdir)
    os.makedirs(dask_workdir)

    # Use total device size to calculate device limit and pool size.
    device_size = device_mem_size(kind="total")
    device_limit = int(device_limit_frac * device_size)
    device_pool_size = int(device_pool_frac * device_size)

    # Query the GPU list once; it drives both the worker count and device ids.
    num_gpus = len(numba.cuda.gpus)

    cluster = LocalCUDACluster(
        protocol="tcp",
        n_workers=num_gpus,
        CUDA_VISIBLE_DEVICES=range(num_gpus),
        device_memory_limit=device_limit,
        local_directory=dask_workdir,
        # Rounded down to a multiple of 256 bytes (presumably to satisfy RMM's
        # pool-size alignment requirement).
        rmm_pool_size=(device_pool_size // 256) * 256,
    )

    return Client(cluster)
40 |
--------------------------------------------------------------------------------
/torchrec/modules/pruning_logger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import logging
9 | from abc import ABC, abstractmethod
10 | from contextlib import contextmanager
11 | from dataclasses import dataclass
12 | from types import SimpleNamespace
13 | from typing import Generator, Optional
14 |
15 | logger: logging.Logger = logging.getLogger(__name__)
16 |
17 |
@dataclass
class PruningLogBase:
    """Field-less dataclass serving as a common base type for pruning log records."""
21 |
22 |
class PruningLogger(ABC):
    """Abstract interface for pruning-event loggers.

    Concrete loggers implement ``pruning_logger`` as a class-level context
    manager; this class cannot be instantiated because that method is abstract.
    """

    @classmethod
    @abstractmethod
    @contextmanager
    def pruning_logger(
        cls,
        event: str,
        trainer: Optional[str] = None,
        publisher: Optional[str] = None,
    ) -> Generator[object, None, None]:
        """Open a logging context for ``event``; implementations yield a log object."""
34 |
35 |
class PruningLoggerDefault(PruningLogger):
    """
    No-op logger used as the default: records nothing and yields an empty
    ``SimpleNamespace`` that callers may freely set attributes on.
    """

    @classmethod
    @contextmanager
    def pruning_logger(
        cls,
        event: str,
        trainer: Optional[str] = None,
        publisher: Optional[str] = None,
    ) -> Generator[object, None, None]:
        # The arguments exist only to match the PruningLogger interface.
        noop_record = SimpleNamespace()
        yield noop_record
50 |
--------------------------------------------------------------------------------
/docs/source/planner-api-reference.rst:
--------------------------------------------------------------------------------
1 | Planner
2 | ----------------------------------
3 |
4 | The TorchRec Planner is responsible for determining the most performant, balanced
5 | sharding plan for distributed training and inference.
6 |
7 | The main API for generating a sharding plan is ``EmbeddingShardingPlanner.plan``.
8 |
9 | .. automodule:: torchrec.distributed.types
10 |
11 | .. autoclass:: ShardingPlan
12 | :members:
13 |
14 | .. automodule:: torchrec.distributed.planner.planners
15 |
16 | .. autoclass:: EmbeddingShardingPlanner
17 | :members:
18 |
19 | .. automodule:: torchrec.distributed.planner.enumerators
20 |
21 | .. autoclass:: EmbeddingEnumerator
22 | :members:
23 |
24 | .. automodule:: torchrec.distributed.planner.partitioners
25 |
26 | .. autoclass:: GreedyPerfPartitioner
27 | :members:
28 |
29 |
30 | .. automodule:: torchrec.distributed.planner.storage_reservations
31 |
32 | .. autoclass:: HeuristicalStorageReservation
33 | :members:
34 |
35 | .. automodule:: torchrec.distributed.planner.proposers
36 |
37 | .. autoclass:: GreedyProposer
38 | :members:
39 |
40 |
41 | .. automodule:: torchrec.distributed.planner.shard_estimators
42 |
43 | .. autoclass:: EmbeddingPerfEstimator
44 | :members:
45 |
46 |
47 | .. automodule:: torchrec.distributed.planner.shard_estimators
48 |
49 | .. autoclass:: EmbeddingStorageEstimator
50 | :members:
51 |
--------------------------------------------------------------------------------
/torchrec/inference/inference_legacy/protos/predictor.proto:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
syntax = "proto3";

package predictor;

// Sparse (jagged) features for a batch, packed into raw byte buffers.
// NOTE(review): T and B presumably denote the number of features and the
// batch size; the element dtypes of `values` and `weights` are not stated
// here — confirm against the client/server serialization code.
message SparseFeatures {
  int32 num_features = 1;
  // int32: T x B
  bytes lengths = 2;
  // T x B x L (jagged)
  bytes values = 3;
  // Presumably per-value weights (used for id-score-list features) — TODO
  // confirm.
  bytes weights = 4;
}

// Dense features packed into a raw byte buffer.
message FloatFeatures {
  int32 num_features = 1;
  // shape: {B}
  bytes values = 2;
}

// A single batched inference request.
message PredictionRequest {
  int32 batch_size = 1;
  FloatFeatures float_features = 2;
  SparseFeatures id_list_features = 3;
  SparseFeatures id_score_list_features = 4;
  FloatFeatures embedding_features = 5;
  SparseFeatures unary_features = 6;
}

// A flat list of floats; presumably the per-task value type of
// PredictionResponse.predictions (see the TODO below).
message FloatVec {
  repeated float data = 1;
}

// TODO: See whether FloatVec can be replaced with folly::iobuf
message PredictionResponse {
  // Task name to prediction Tensor
  map predictions = 1;
}

// The predictor service definition. Synchronous for now.
service Predictor {
  rpc Predict(PredictionRequest) returns (PredictionResponse) {}
}
51 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.11.0 FATAL_ERROR)

project(torchrec_dynamic_embedding
        VERSION 0.0.1
        LANGUAGES CXX C)

# Directory of the installed `torch` Python package. libtorch's CMake package
# config is searched for only under this path. This is a path, not an ON/OFF
# switch, so it is declared as a PATH cache entry rather than with option(),
# which always creates a BOOL entry. A value supplied with -D on the command
# line is kept, because set(... CACHE ...) does not overwrite an existing
# entry.
set(TDE_TORCH_BASE_DIR "" CACHE PATH "torch python directory")

if (NOT TDE_TORCH_BASE_DIR)
  # message() concatenates its arguments without separators, so each piece
  # carries its own leading/trailing space.
  message(FATAL_ERROR "TDE_TORCH_BASE_DIR must be set. "
          "Use python -c 'import torch;import os.path;print(os.path.dirname(torch.__file__))'"
          " to get this dir.")
else()
  find_package(Torch REQUIRED
               PATHS "${TDE_TORCH_BASE_DIR}"
               NO_DEFAULT_PATH)
endif()

# C++ tests are enabled by default only when this is the top-level project.
# The paths are quoted so directories containing spaces still compare as a
# single argument each.
if ("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}")
  set(TDE_IS_TOP_LEVEL_PROJECT ON)
else()
  set(TDE_IS_TOP_LEVEL_PROJECT OFF)
endif()

option(TDE_WITH_TESTING "Enable unittest in C++ side" ${TDE_IS_TOP_LEVEL_PROJECT})

# With GCC, _GLIBCXX_USE_CXX11_ABI must match the ABI libtorch was built with.
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
  option(TDE_WITH_CXX11_ABI "GLIBCXX use c++11 ABI or not. libtorch installed by conda is not use it by default" OFF)
  if (TDE_WITH_CXX11_ABI)
    add_definitions("-D_GLIBCXX_USE_CXX11_ABI=1")
  else()
    add_definitions("-D_GLIBCXX_USE_CXX11_ABI=0")
  endif()
endif()

if (TDE_WITH_TESTING)
  enable_testing()
  add_subdirectory(tests/memory_io)
endif()
add_subdirectory(src)
42 |
--------------------------------------------------------------------------------
/torchrec/csrc/dynamic_embedding/details/io_registry.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 |
namespace torchrec {

// C-style function table describing one IO backend. Plain function pointers
// (rather than virtual methods) presumably allow providers to come from
// dynamically loaded plugins — see IORegistry::register_plugin.
struct IOProvider {
  // Provider name; presumably the key that IORegistry::resolve looks up.
  const char* type;
  // Creates a provider instance from a config string. The returned pointer is
  // presumably what is passed back as `instance` to fetch/push/finalize.
  void* (*initialize)(const char* cfg);
  // Issues a fetch request against `instance`.
  void (*fetch)(void* instance, IOFetchParameter cfg);
  // Issues a push request against `instance`.
  void (*push)(void* instance, IOPushParameter cfg);
  // Destroys an instance created by `initialize`.
  void (*finalize)(void*);
};

// Registry of IOProviders. The constructor is private and `Instance()` is the
// only accessor, i.e. a singleton.
class IORegistry {
 public:
  // Adds a provider to the registry.
  void register_provider(IOProvider provider);
  // Loads a plugin from `filename`. Presumably a shared library whose handle
  // is kept in `dls_` and released by DLCloser — confirm in the .cpp.
  void register_plugin(const char* filename);
  // Looks up a provider by name; the result must not be discarded.
  [[nodiscard]] IOProvider resolve(const std::string& name) const;

  static IORegistry& Instance();

 private:
  IORegistry() = default;
  // Registered providers, presumably keyed by IOProvider::type.
  ska::flat_hash_map providers_;
  // Deleter for plugin handles held in `dls_`.
  struct DLCloser {
    void operator()(void* ptr) const;
  };

  using DLPtr = std::unique_ptr;
  // Owned handles of loaded plugins.
  std::vector dls_;
};

} // namespace torchrec
48 |
--------------------------------------------------------------------------------
/torchrec/csrc/dynamic_embedding/details/bitmap.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 | #include
11 | #include
12 |
namespace torchrec {

/**
 * Bitmap
 *
 * A fixed-size bitmap recording, for each of `num_bits` slots, whether the
 * slot is occupied or free. Bits are presumably packed into `num_values_`
 * words of type T (`num_bits_per_value` bits each) stored in `values_`.
 *
 * Not copyable (copy constructor deleted); movable via a defaulted noexcept
 * move constructor. Member function definitions are presumably in the
 * implementation header included at the end of this file.
 */
template
struct Bitmap {
  explicit Bitmap(int64_t num_bits);
  Bitmap(const Bitmap&) = delete;
  Bitmap(Bitmap&&) noexcept = default;

  /**
   * Returns the position of the next free slot.
   * If the bitmap is full, returns `num_total_bits_`.
   */
  int64_t next_free_bit();

  /**
   * Marks the slot at position `offset` as free.
   */
  void free_bit(int64_t offset);

  /**
   * Returns true if every slot in the bitmap is occupied.
   */
  bool full() const;

  // Number of bits held by one storage word of type T.
  static constexpr int64_t num_bits_per_value = sizeof(T) * 8;

  // Total number of slots (the `num_bits` constructor argument, presumably).
  const int64_t num_total_bits_;
  // Number of storage words in `values_` (presumably).
  const int64_t num_values_;
  std::unique_ptr values_;

  // Presumably the search cursor used by next_free_bit() — confirm.
  int64_t next_free_bit_;
};

} // namespace torchrec
53 |
54 | #include
55 |
--------------------------------------------------------------------------------
/torchrec/ir/types.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | #!/usr/bin/env python3
11 |
12 | import abc
13 | from typing import Any, Dict, List, Optional
14 |
15 | import torch
16 |
17 | from torch import nn
18 |
19 |
class _ClassProperty:
    """Read-only property evaluated against the owning class.

    Stand-in for stacking ``@classmethod`` on ``@property``. That chaining was
    deprecated in Python 3.11 and stops working in Python 3.13, where
    ``classmethod`` no longer forwards to the wrapped descriptor's ``__get__``.
    This descriptor calls the getter with the owning class on both class and
    instance access, matching the old behavior.
    """

    def __init__(self, fget: Any) -> None:
        self.fget: Any = fget
        self.__doc__ = getattr(fget, "__doc__", None)

    def __get__(self, instance: Any, owner: Any = None) -> Any:
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class SerializerInterface(abc.ABC):
    """
    Interface for Serializer classes for torch.export IR.
    """

    @_ClassProperty
    def module_to_serializer_cls(cls) -> Dict[str, Any]:
        # Class-level attribute that concrete serializers must provide (e.g. as
        # a plain class attribute); reading it on a class that does not
        # override it raises NotImplementedError.
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def encapsulate_module(cls, module: nn.Module) -> List[str]:
        # Take the eager embedding module and encapsulate the module, including serialization
        # and meta_forward-swapping, then returns a list of children (fqns) which needs further encapsulation
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def decapsulate_module(
        cls, module: nn.Module, device: Optional[torch.device] = None
    ) -> nn.Module:
        # Take the eager embedding module and decapsulate it by removing serialization and meta_forward-swapping
        raise NotImplementedError
44 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/mixed_lfu_lru_strategy_benchmark.cpp:
--------------------------------------------------------------------------------
1 | #include "benchmark/benchmark.h"
2 | #include "tde/details/mixed_lfu_lru_strategy.h"
3 |
namespace tde::details {
// Benchmarks MixedLFULRUStrategy::Update over `num_ext_values` tracked
// entries. Each timed iteration advances the strategy's time by one and then
// re-updates `num_elems_per_iter` entries chosen at random. Generating the
// random offsets is excluded from timing via PauseTiming/ResumeTiming.
void BM_MixedLFULRUStrategy(benchmark::State& state) {
  size_t num_ext_values = state.range(0);
  std::vector ext_values(num_ext_values);

  MixedLFULRUStrategy strategy;
  // Seed every tracked entry with an initial Update() result.
  for (auto& v : ext_values) {
    v = strategy.Update(0, 0, std::nullopt);
  }

  size_t num_elems = state.range(1);
  std::default_random_engine engine((std::random_device())());
  size_t time = 0;
  for (auto _ : state) {
    state.PauseTiming();
    std::vector offsets;
    offsets.reserve(num_elems);
    for (size_t i = 0; i < num_elems; ++i) {
      // NOTE(review): offsets are drawn from [0, num_elems - 1], so only the
      // first num_elems entries of ext_values are ever updated, and a
      // num_elems larger than num_ext_values would index past the end of
      // ext_values. Presumably the upper bound was meant to be
      // num_ext_values - 1 — confirm. The distribution could also be
      // constructed once, outside this loop.
      std::uniform_int_distribution dist(0, num_elems - 1);
      offsets.emplace_back(dist(engine));
    }
    state.ResumeTiming();

    ++time;
    strategy.UpdateTime(time);
    for (auto& v : offsets) {
      ext_values[v] = strategy.Update(0, 0, ext_values[v]);
    }
  }
}

// 30M and 300M tracked entries, 1M updates per iteration, 100 iterations,
// reported in milliseconds.
BENCHMARK(BM_MixedLFULRUStrategy)
    ->ArgNames({"num_ext_values", "num_elems_per_iter"})
    ->Args({30000000, 1024 * 1024})
    ->Args({300000000, 1024 * 1024})
    ->Unit(benchmark::kMillisecond)
    ->Iterations(100);

} // namespace tde::details
43 |
--------------------------------------------------------------------------------
/torchrec/datasets/scripts/nvt/utils/criteo_constant.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This file is needed because NVTabular will create many ghost processes if we
9 | # import from the TorchRec criteo datasets module.
10 |
11 | from typing import List
12 |
13 |
FREQUENCY_THRESHOLD = 3
INT_FEATURE_COUNT = 13
CAT_FEATURE_COUNT = 26
DAYS = 24
DEFAULT_LABEL_NAME = "label"

# Column names: the label, then "int_0".."int_12", then "cat_0".."cat_25".
DEFAULT_INT_NAMES: List[str] = ["int_" + str(i) for i in range(INT_FEATURE_COUNT)]
DEFAULT_CAT_NAMES: List[str] = ["cat_" + str(i) for i in range(CAT_FEATURE_COUNT)]
DEFAULT_COLUMN_NAMES: List[str] = (
    [DEFAULT_LABEL_NAME] + DEFAULT_INT_NAMES + DEFAULT_CAT_NAMES
)

# Number of embeddings for each categorical feature, in the same order as
# DEFAULT_CAT_NAMES (the two lists are paired positionally below).
NUM_EMBEDDINGS_PER_FEATURE: List[int] = [
    40000000, 39060, 17295, 7424, 20265, 3, 7122, 1543, 63,
    40000000, 3067956, 405282, 10, 2209, 11938, 155, 4, 976, 14,
    40000000, 40000000, 40000000, 590152, 12973, 108, 36,
]

# Categorical feature name -> number of embeddings.
NUM_EMBEDDINGS_PER_FEATURE_DICT = {
    cat_name: num_embeddings
    for cat_name, num_embeddings in zip(DEFAULT_CAT_NAMES, NUM_EMBEDDINGS_PER_FEATURE)
}
58 |
--------------------------------------------------------------------------------
/examples/nvt_dataloader/aws_component.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import os
9 |
10 | import torchx.specs as specs
11 | from torchx.components.dist import ddp
12 |
13 |
def run_dlrm_main(num_trainers: int = 8, *script_args: str) -> specs.AppDef:
    """Build a TorchX DDP AppDef that runs ``train_torchrec.py`` on AWS p4d hosts.

    Up to 8 trainers run on a single host with one process per trainer. Larger
    jobs use ``num_trainers // 8`` hosts with 8 processes each.

    Args:
        num_trainers: The number of trainers to use. Must be at least 1, and a
            multiple of 8 when greater than 8.
        script_args: A variable number of parameters to provide train_torchrec.py.

    Returns:
        The ``specs.AppDef`` produced by ``torchx.components.dist.ddp``.

    Raises:
        ValueError: If ``num_trainers`` is less than 1, or is greater than 8
            and not a multiple of 8.
    """
    # Without this check, num_trainers <= 0 produced a "1x0" (or negative)
    # process layout instead of an error.
    if num_trainers < 1:
        raise ValueError(f"num_trainers must be at least 1, got {num_trainers}.")
    if num_trainers > 8 and num_trainers % 8 != 0:
        raise ValueError(
            "Trainer jobs spanning multiple hosts must be in multiples of 8."
        )

    cwd = os.getcwd()
    entrypoint = os.path.join(cwd, "train_torchrec.py")

    user = os.environ.get("USER")
    image = f"/data/home/{user}"

    nproc_per_node = min(num_trainers, 8)
    num_replicas = max(num_trainers // 8, 1)

    return ddp(
        *script_args,
        name="train_dlrm",
        image=image,
        # AWS p4d instance (https://aws.amazon.com/ec2/instance-types/p4/).
        cpu=96,
        gpu=8,
        memMB=-1,
        script=entrypoint,
        j=f"{num_replicas}x{nproc_per_node}",
    )
44 |
--------------------------------------------------------------------------------
/torchrec/distributed/model_tracker/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | """Torchrec Model Tracker
11 |
12 | The model tracker module provides functionality to track and retrieve unique IDs and
13 | embeddings for supported modules during training. This is useful for identifying and
14 | retrieving the latest delta or unique rows for a model, which can help compute top-k
15 | results or stream updated embeddings from predictors to trainers during online training.
16 |
17 | Key features include:
18 | - Tracking unique IDs and embeddings for supported modules
19 | - Support for multiple consumers with independent tracking
20 | - Configurable tracking modes (ID_ONLY, EMBEDDING)
21 | - Compaction of tracked data to reduce memory usage
22 | """
23 |
24 | from torchrec.distributed.model_tracker.delta_store import DeltaStore # noqa
25 | from torchrec.distributed.model_tracker.model_delta_tracker import (
26 | ModelDeltaTracker, # noqa
27 | SUPPORTED_MODULES, # noqa
28 | )
29 | from torchrec.distributed.model_tracker.types import (
30 | IndexedLookup, # noqa
31 | ModelTrackerConfigs, # noqa
32 | Trackers, # noqa
33 | TrackingMode, # noqa
34 | UniqueRows, # noqa
35 | UpdateMode, # noqa
36 | )
37 |
--------------------------------------------------------------------------------
/examples/golden_training/tests/test_train_dlrm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import os
11 | import tempfile
12 | import unittest
13 | import uuid
14 |
15 | from torch.distributed.launcher.api import elastic_launch, LaunchConfig
16 | from torchrec.test_utils import skip_if_asan
17 |
18 | from ..train_dlrm import train
19 |
20 |
class TrainTest(unittest.TestCase):
    """Smoke test: runs the golden DLRM ``train`` loop in 2 local processes."""

    @classmethod
    def _run_train(cls) -> None:
        # Entry point executed by every worker process started by elastic_launch.
        train(embedding_dim=16, num_iterations=10)

    @skip_if_asan
    def test_train_function(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdir:
            # File-backed c10d rendezvous inside the temporary directory.
            rdzv_file = os.path.join(tmpdir, "rdzv")
            launch_config = LaunchConfig(
                min_nodes=1,
                max_nodes=1,
                nproc_per_node=2,
                run_id=str(uuid.uuid4()),
                rdzv_backend="c10d",
                rdzv_endpoint=rdzv_file,
                rdzv_configs={"store_type": "file"},
                start_method="spawn",
                monitor_interval=1,
                max_restarts=0,
            )

            launcher = elastic_launch(config=launch_config, entrypoint=self._run_train)
            launcher()
46 |
--------------------------------------------------------------------------------
/torchrec/sparse/tensor_dict.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | from typing import List, Optional
11 |
12 | import torch
13 | from tensordict import TensorDict
14 |
15 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
16 |
17 |
def maybe_td_to_kjt(
    features: KeyedJaggedTensor, keys: Optional[List[str]] = None
) -> KeyedJaggedTensor:
    """Convert a ``TensorDict`` of per-key jagged tensors into a ``KeyedJaggedTensor``.

    Any input that is not a ``TensorDict`` is returned unchanged. Under
    TorchScript the input is asserted to be a ``KeyedJaggedTensor`` and
    returned as-is.

    Args:
        features: A ``KeyedJaggedTensor``, or a ``TensorDict`` whose entries
            expose ``_values`` and either ``_lengths`` or ``_offsets``
            (presumably jagged nested tensors — confirm).
        keys: Keys to gather, in output order. Defaults to all keys of the
            ``TensorDict``.

    Returns:
        A ``KeyedJaggedTensor`` whose values and lengths are the per-key
        values/lengths concatenated in ``keys`` order.
    """
    if torch.jit.is_scripting():
        assert isinstance(features, KeyedJaggedTensor)
        return features
    if not isinstance(features, TensorDict):
        return features

    selected_keys = list(features.keys()) if keys is None else keys
    per_key = [features[key] for key in selected_keys]

    values = torch.cat([jagged._values for jagged in per_key], dim=0)
    # Prefer stored lengths; otherwise derive them from consecutive offsets.
    lengths = torch.cat(
        [
            jagged._lengths
            if jagged._lengths is not None
            else torch.diff(jagged._offsets)
            for jagged in per_key
        ],
        dim=0,
    )
    return KeyedJaggedTensor(keys=selected_keys, values=values, lengths=lengths)
46 |
--------------------------------------------------------------------------------
/torchrec/inference/inference_legacy/tests/ValidationTest.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include "torchrec/inference/Validation.h"
10 |
11 | #include
12 | #include
13 |
14 | TEST(ValidationTest, validateSparseFeatures) {
15 | auto values = at::tensor({1, 2, 3, 4});
16 | auto lengths = at::tensor({1, 1, 1, 1});
17 | auto weights = at::tensor({.1, .2, .3, .4});
18 |
19 | // pass 1D
20 | EXPECT_TRUE(torchrec::validateSparseFeatures(values, lengths, weights));
21 |
22 | // pass 2D
23 | lengths.reshape({2, 2});
24 | EXPECT_TRUE(torchrec::validateSparseFeatures(values, lengths, weights));
25 |
26 | // fail 1D
27 | auto invalidLengths = at::tensor({1, 2, 1, 1});
28 | EXPECT_FALSE(
29 | torchrec::validateSparseFeatures(values, invalidLengths, weights));
30 |
31 | // fail 2D
32 | invalidLengths.reshape({2, 2});
33 | EXPECT_FALSE(
34 | torchrec::validateSparseFeatures(values, invalidLengths, weights));
35 | }
36 |
TEST(ValidationTest, validateDenseFeatures) {
  // Four dense values: accepted with 1 and 4, rejected with 3.
  // NOTE(review): consistent with a "value count divisible by the given size"
  // check — confirm against validateDenseFeatures.
  auto values = at::tensor({1, 2, 3, 4});
  EXPECT_TRUE(torchrec::validateDenseFeatures(values, 1));
  EXPECT_TRUE(torchrec::validateDenseFeatures(values, 4));
  EXPECT_FALSE(torchrec::validateDenseFeatures(values, 3));
}
43 |
--------------------------------------------------------------------------------
/torchrec/distributed/sharding/twcw_sharding.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | from typing import Dict, List, Optional
11 |
12 | import torch
13 | from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo
14 | from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
15 | from torchrec.distributed.types import QuantizedCommCodecs, ShardingEnv
16 |
17 |
class TwCwPooledEmbeddingSharding(CwPooledEmbeddingSharding):
    """
    Shards embedding bags table-wise column-wise, i.e. a given embedding table is
    partitioned along its columns and the table slices are placed on all ranks
    within a host group.

    This class defines no behavior of its own: its constructor forwards every
    argument unchanged to ``CwPooledEmbeddingSharding``.
    """

    def __init__(
        self,
        sharding_infos: List[EmbeddingShardingInfo],
        env: ShardingEnv,
        device: Optional[torch.device] = None,
        permute_embeddings: bool = False,
        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
    ) -> None:
        super(TwCwPooledEmbeddingSharding, self).__init__(
            sharding_infos,
            env,
            device,
            qcomm_codecs_registry=qcomm_codecs_registry,
            permute_embeddings=permute_embeddings,
        )
40 |
--------------------------------------------------------------------------------
/torchrec/ir/schema.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | from dataclasses import dataclass
11 | from typing import List, Optional, Tuple
12 |
13 | from torchrec.modules.embedding_configs import DataType, PoolingType
14 |
15 |
# Same as EmbeddingBagConfig but serializable: only plain and enum-typed fields.
# Field names presumably mirror EmbeddingBagConfig's attributes — confirm with
# the (de)serializer. Field order defines the generated __init__ signature, so
# do not reorder.
@dataclass
class EmbeddingBagConfigMetadata:
    num_embeddings: int
    embedding_dim: int
    name: str
    data_type: DataType
    feature_names: List[str]
    weight_init_max: Optional[float]
    weight_init_min: Optional[float]
    need_pos: bool
    pooling: PoolingType
28 |
29 |
@dataclass
class EBCMetadata:
    # Serializable metadata for an embedding-bag collection ("EBC"
    # presumably = EmbeddingBagCollection — confirm with the serializer).
    tables: List[EmbeddingBagConfigMetadata]  # one entry per table config
    is_weighted: bool
    # Device kept as a string (or None); presumably str(torch.device) — confirm.
    device: Optional[str]
35 |
36 |
@dataclass
class FPEBCMetadata:
    # Serializable metadata for a feature-processed EBC ("FPEBC").
    # NOTE(review): is_fp_collection presumably distinguishes a collection-level
    # feature processor from per-feature processors — TODO confirm.
    is_fp_collection: bool
    features: List[str]  # feature names handled by the feature processor(s)
41 |
42 |
@dataclass
class PositionWeightedModuleMetadata:
    # Serializable state of a single position-weighted module.
    max_feature_length: int
46 |
47 |
@dataclass
class PositionWeightedModuleCollectionMetadata:
    # Serializable state of a collection of position-weighted modules.
    # Presumably (feature_name, max_feature_length) pairs — TODO confirm.
    max_feature_lengths: List[Tuple[str, int]]
51 |
52 |
@dataclass
class KTRegroupAsDictMetadata:
    # Serializable metadata for a regroup-as-dict module ("KT" presumably =
    # KeyedTensor — confirm with the serializer).
    groups: List[List[str]]  # presumably the input keys in each output group
    keys: List[str]  # presumably one output dict key per group — confirm
    emb_dtype: Optional[str]  # embedding dtype as a string, or None if unset
58 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
# FetchContent_MakeAvailable (used below when BUILD_TEST is ON) requires
# CMake 3.14, and CMAKE_CXX_STANDARD 20 is only understood from CMake 3.12,
# so 3.11 was not actually a sufficient minimum.
cmake_minimum_required(VERSION 3.14.0 FATAL_ERROR)

project(TorchRec
  LANGUAGES CXX C)

find_package(Torch REQUIRED)

set(CMAKE_CXX_STANDARD 20)

include(FetchContent)

option(BUILD_TEST "Build C++ test binaries (need gtest and gbenchmark)" OFF)

# Compile with the pre-C++11 libstdc++ ABI.
# NOTE(review): this must match the ABI of the libtorch being linked against;
# confirm for the libtorch build in use.
add_definitions("-D_GLIBCXX_USE_CXX11_ABI=0")

add_subdirectory(torchrec/csrc)

if (BUILD_TEST)
  # Pinned test/benchmark dependencies, fetched at configure time.
  FetchContent_Declare(googletest
    GIT_REPOSITORY https://github.com/google/googletest.git
    GIT_TAG v1.12.0
  )
  FetchContent_Declare(google_benchmark
    GIT_REPOSITORY https://github.com/google/benchmark.git
    GIT_TAG v1.5.6
  )

  # We will not need to test benchmark lib itself.
  set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark testing as we don't need it.")
  # We will not need to install benchmark since we link it statically.
  set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "Disable benchmark install to avoid overwriting vendor install.")
  FetchContent_MakeAvailable(googletest google_benchmark)

  enable_testing()
  add_subdirectory(benchmarks/cpp)
  add_subdirectory(test/cpp)
endif()
44 |
--------------------------------------------------------------------------------
/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include
10 | #include
11 | #include
12 |
namespace torchrec {

// Benchmarks NaiveIDTransformer::transform on a 1024x1024 batch of global ids.
// The ids are regenerated randomly each iteration between the two benchmark
// args (rand_from, rand_to).
// NOTE(review): template arguments (e.g. on data_ptr and static_cast, and
// possibly on NaiveIDTransformer / std::span) appear to have been stripped from
// this copy; the code as shown will not compile — restore them from upstream.
static void BM_NaiveIDTransformer(benchmark::State& state) {
  // NOTE(review): 2e8 is presumably the transformer's capacity (number of
  // cache slots) — confirm against the constructor.
  NaiveIDTransformer transformer(2e8);
  torch::Tensor global_ids = torch::empty({1024, 1024}, torch::kLong);
  torch::Tensor cache_ids = torch::empty_like(global_ids);
  for (auto _ : state) {
    // Refill the inputs outside the timed region so only transform() is measured.
    state.PauseTiming();
    global_ids.random_(state.range(0), state.range(1));
    state.ResumeTiming();
    // Map every global id to a cache id, writing into cache_ids.
    transformer.transform(
        std::span{
            global_ids.template data_ptr(),
            static_cast(global_ids.numel())},
        std::span{
            cache_ids.template data_ptr(),
            static_cast(cache_ids.numel())});
  }
}

// Two id ranges: a very large one ([1e10, 2e10)) and a small one ([1e6, 2e6)).
BENCHMARK(BM_NaiveIDTransformer)
    ->Iterations(100)
    ->Unit(benchmark::kMillisecond)
    ->ArgNames({"rand_from", "rand_to"})
    ->Args({static_cast(1e10), static_cast(2e10)})
    ->Args({static_cast(1e6), static_cast(2e6)});

} // namespace torchrec
41 |
--------------------------------------------------------------------------------
/torchrec/distributed/tests/test_awaitable.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import unittest
11 |
12 | import torch
13 | from torchrec.distributed.types import Awaitable
14 |
15 |
class AwaitableInstance(Awaitable[torch.Tensor]):
    """Test awaitable whose wait() resolves to the fixed tensor [1.0, 2.0, 3.0]."""

    def __init__(self) -> None:
        super().__init__()

    def _wait_impl(self) -> torch.Tensor:
        resolved_values = [1.0, 2.0, 3.0]
        return torch.FloatTensor(resolved_values)
22 |
23 |
class AwaitableTests(unittest.TestCase):
    """Checks that Awaitable.wait() pipes its result through registered callbacks."""

    def _assert_wait_result(
        self, awaitable: AwaitableInstance, expected: torch.Tensor
    ) -> None:
        self.assertTrue(torch.allclose(awaitable.wait(), expected))

    def test_callback(self) -> None:
        awaitable = AwaitableInstance()
        # pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got
        #  `(ret: Any) -> int`.
        awaitable.callbacks.append(lambda ret: 2 * ret)
        self._assert_wait_result(awaitable, torch.FloatTensor([2.0, 4.0, 6.0]))

    def test_callback_chained(self) -> None:
        awaitable = AwaitableInstance()
        # Callbacks apply in registration order: (2 * x) ** 2.
        # pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got
        #  `(ret: Any) -> int`.
        awaitable.callbacks.append(lambda ret: 2 * ret)
        awaitable.callbacks.append(lambda ret: ret**2)
        self._assert_wait_result(awaitable, torch.FloatTensor([4.0, 16.0, 36.0]))
42 | )
43 |
--------------------------------------------------------------------------------
/torchrec/distributed/composable/tests/test_table_batched_embedding_slice.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import copy
11 | import unittest
12 |
13 | import torch
14 |
15 | from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
16 | DenseTableBatchedEmbeddingBagsCodegen,
17 | )
18 | from torchrec.distributed.composable.table_batched_embedding_slice import (
19 | TableBatchedEmbeddingSlice,
20 | )
21 |
22 |
class TestTableBatchedEmbeddingSlice(unittest.TestCase):
    """Tests that TableBatchedEmbeddingSlice views the packed TBE weights."""

    def _build_emb(self) -> DenseTableBatchedEmbeddingBagsCodegen:
        # Two 2x4 tables packed into one weight buffer; use the CPU kernels
        # whenever no GPU is available.
        use_cpu = not torch.cuda.is_available()
        return DenseTableBatchedEmbeddingBagsCodegen(
            [(2, 4), (2, 4)], use_cpu=use_cpu
        )

    def test_is_view(self) -> None:
        emb = self._build_emb()
        # Slice covering the first table (args presumably start, end, rows, cols).
        first_table = TableBatchedEmbeddingSlice(emb.weights, 0, 8, 2, 4)
        # Shares storage with the packed weights, starting at offset 0.
        self.assertEqual(first_table.data_ptr(), emb.weights.data_ptr())

    def test_copy(self) -> None:
        emb = self._build_emb()
        first_table = TableBatchedEmbeddingSlice(emb.weights, 0, 8, 2, 4)
        # A deep copy must own its own storage.
        copied = copy.deepcopy(first_table)
        self.assertNotEqual(first_table.data_ptr(), copied.data_ptr())
40 |
--------------------------------------------------------------------------------
/examples/retrieval/tests/test_two_tower_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import os
11 | import tempfile
12 | import unittest
13 | import uuid
14 |
15 | from torch.distributed.launcher.api import elastic_launch, LaunchConfig
16 | from torchrec.test_utils import skip_if_asan
17 |
18 | # @manual=//torchrec/github/examples/retrieval:two_tower_train_lib
19 | from ..two_tower_train import train
20 |
21 |
class TrainTest(unittest.TestCase):
    """Runs the two-tower training entry point under a 2-process elastic launch."""

    @classmethod
    def _run_train(cls) -> None:
        # Tiny configuration so each worker finishes quickly.
        train(embedding_dim=16, layer_sizes=[16], num_iterations=10)

    @skip_if_asan
    def test_train_function(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdir:
            # File-based c10d rendezvous inside the temporary directory.
            rendezvous_file = os.path.join(tmpdir, "rdzv")
            launch_config = LaunchConfig(
                min_nodes=1,
                max_nodes=1,
                nproc_per_node=2,
                run_id=str(uuid.uuid4()),
                rdzv_backend="c10d",
                rdzv_endpoint=rendezvous_file,
                rdzv_configs={"store_type": "file"},
                start_method="spawn",
                monitor_interval=1,
                max_restarts=0,
            )

            launcher = elastic_launch(config=launch_config, entrypoint=self._run_train)
            launcher()
48 |
--------------------------------------------------------------------------------
/torchrec/optim/fused.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import abc
11 | from typing import Any
12 |
13 | from torch import optim
14 | from torchrec.optim.keyed import KeyedOptimizer
15 |
16 |
class FusedOptimizer(KeyedOptimizer, abc.ABC):
    """
    Base class for optimizers whose weight update happens during the backward
    pass, so step() is expected to be a no-op.

    Subclasses must implement ``step`` and ``zero_grad``.
    """

    @abc.abstractmethod
    # pyre-ignore [2]
    def step(self, closure: Any = None) -> None:
        """Abstract: the update is expected to have happened during backward."""

    @abc.abstractmethod
    def zero_grad(self, set_to_none: bool = False) -> None:
        """Abstract: subclasses decide how gradients are cleared."""

    def __repr__(self) -> str:
        # Explicitly use torch.optim.Optimizer's repr implementation.
        return optim.Optimizer.__repr__(self)
32 |
33 |
class EmptyFusedOptimizer(FusedOptimizer):
    """
    FusedOptimizer with nothing to optimize: constructed from three empty
    mappings, and both step() and zero_grad() do nothing.
    """

    def __init__(self) -> None:
        # Empty params, state and param groups.
        super().__init__({}, {}, {})

    # pyre-ignore
    def step(self, closure: Any = None) -> None:
        """No-op: there are no parameters to update."""

    def zero_grad(self, set_to_none: bool = False) -> None:
        """No-op: there are no gradients to clear."""
48 |
49 |
class FusedOptimizerModule(abc.ABC):
    """
    Interface for modules that apply their weight update during the backward
    pass and expose the optimizer used for it.
    """

    @property
    @abc.abstractmethod
    def fused_optimizer(self) -> KeyedOptimizer:
        """Optimizer applied during the backward pass; must be overridden."""
58 |
--------------------------------------------------------------------------------
/torchrec/distributed/train_pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 |
11 | from torchrec.distributed.train_pipeline.pipeline_context import ( # noqa
12 | In,
13 | Out,
14 | TrainPipelineContext,
15 | )
16 | from torchrec.distributed.train_pipeline.pipeline_stage import ( # noqa
17 | SparseDataDistUtil, # noqa
18 | StageOut, # noqa
19 | )
20 | from torchrec.distributed.train_pipeline.tracing import ( # noqa
21 | ArgInfoStepFactory, # noqa
22 | Tracer, # noqa
23 | )
24 | from torchrec.distributed.train_pipeline.train_pipelines import ( # noqa
25 | EvalPipelineSparseDist, # noqa
26 | PrefetchTrainPipelineSparseDist, # noqa
27 | StagedTrainPipeline, # noqa
28 | TorchCompileConfig, # noqa
29 | TrainPipeline, # noqa
30 | TrainPipelineBase, # noqa
31 | TrainPipelineFusedSparseDist, # noqa
32 | TrainPipelinePT2, # noqa
33 | TrainPipelineSparseDist, # noqa
34 | TrainPipelineSparseDistCompAutograd, # noqa
35 | )
36 | from torchrec.distributed.train_pipeline.types import ArgInfo, CallArgs # noqa
37 | from torchrec.distributed.train_pipeline.utils import ( # noqa
38 | _override_input_dist_forwards, # noqa
39 | _rewrite_model, # noqa
40 | _start_data_dist, # noqa
41 | _to_device, # noqa
42 | _wait_for_batch, # noqa
43 | DataLoadingThread, # noqa
44 | )
45 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) Meta Platforms, Inc. and affiliates.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/torchrec/quant/tests/test_tensor_types.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | #!/usr/bin/env python3
11 |
12 | import unittest
13 |
14 | import torch
15 |
16 | from torchrec.tensor_types import UInt2Tensor, UInt4Tensor
17 |
18 |
class QuantUtilsTest(unittest.TestCase):
    # Exercises the UInt4Tensor / UInt2Tensor wrappers over a packed uint8
    # buffer: viewing back as uint8, slicing, and in-place copy between slices.
    # pyre-ignore
    @unittest.skipIf(
        torch.cuda.device_count() <= 0,
        "Not enough GPUs available",
    )
    def test_uint42_tensor(self) -> None:
        t_u8 = torch.tensor(
            [
                [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
                [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
                [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            ],
            dtype=torch.uint8,
        )
        t_u4 = UInt4Tensor(t_u8)
        # NOTE(review): detach() returns a new tensor and the result is
        # discarded, so this line does not change t_u4.
        t_u4.detach()

        # NOTE(review): .to() is not in-place and its result is discarded, so
        # t_u4 is unchanged and the CUDA move is not actually exercised. If that
        # is the intent, assign the result and compare against a t_u8 on the
        # same device. The same applies to t_u2 below.
        t_u4.to(torch.device("cuda"))
        assert torch.equal(t_u4.view(torch.uint8), t_u8)
        t_u2 = UInt2Tensor(t_u8)
        t_u2.to(torch.device("cuda"))
        assert torch.equal(t_u2.view(torch.uint8), t_u8)

        # Assumes UInt4Tensor reports two 4-bit elements per byte (8 bytes ->
        # 16 columns), so these are two 8-wide halves — TODO confirm.
        for t in [t_u4[:, :8], t_u4[:, 8:]]:
            assert t.size(1) == 8
        t_u4[:, :8].copy_(t_u4[:, 8:])

        # Assumes UInt2Tensor reports four 2-bit elements per byte, so these
        # are 4-wide column slices — TODO confirm.
        for t in [t_u2[:, 4:8], t_u2[:, 8:12]]:
            assert t.size(1) == 4

        t_u2[:, 4:8].copy_(t_u2[:, 8:12])
51 |
--------------------------------------------------------------------------------
/torchrec/inference/inference_legacy/tests/test_modules.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # pyre-strict
4 |
5 | # Copyright (c) Meta Platforms, Inc. and affiliates.
6 | # All rights reserved.
7 | #
8 | # This source code is licensed under the BSD-style license found in the
9 | # LICENSE file in the root directory of this source tree.
10 |
11 | #!/usr/bin/env python3
12 | # @nolint
13 |
14 | import unittest
15 |
16 | from torchrec.distributed.test_utils.infer_utils import TorchTypesModelInputWrapper
17 | from torchrec.distributed.test_utils.test_model import TestSparseNN
18 | from torchrec.inference.modules import quantize_inference_model, shard_quant_model
19 | from torchrec.modules.embedding_configs import EmbeddingBagConfig
20 |
21 |
class EagerModelProcessingTests(unittest.TestCase):
    """Quantize-then-shard smoke test for an eager TestSparseNN model."""

    def test_quantize_shard_cuda(self) -> None:
        # Ten single-feature tables: table_<i> is looked up by feature_<i>.
        tables = [
            EmbeddingBagConfig(
                num_embeddings=10,
                embedding_dim=4,
                name=f"table_{idx}",
                feature_names=[f"feature_{idx}"],
            )
            for idx in range(10)
        ]
        wrapped_model = TorchTypesModelInputWrapper(TestSparseNN(tables=tables))

        sharded_model, _ = shard_quant_model(quantize_inference_model(wrapped_model))

        # After sharding, all tables are expected to live in a single TBE.
        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `sparse`.
        sharded_qebc = sharded_model._module.sparse.ebc
        self.assertEqual(len(sharded_qebc.tbes), 1)
46 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/url_test.cpp:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "tde/details/url.h"
3 |
namespace tde::details::url_parser::rules {

// NOTE(review): the `lexy::parse(ipt, ...)` calls below appear to have lost
// their grammar template argument in this copy (presumably
// lexy::parse<SomeRule>(...)); restore from upstream before compiling.

// A percent-encoded single character: "%61" decodes to 'a'.
TEST(TDE, url_token) {
  auto ipt = lexy::string_input(std::string_view("%61"));
  auto parse = lexy::parse(ipt, lexy_ext::report_error);
  ASSERT_TRUE(parse.has_value());
  ASSERT_EQ('a', parse.value());
}

// Percent-decoding inside a longer string: "%61bc" -> "abc".
TEST(TDE, url_string) {
  auto ipt = lexy::string_input(std::string_view("%61bc"));
  auto parse = lexy::parse(ipt, lexy_ext::report_error);
  ASSERT_TRUE(parse.has_value());
  ASSERT_EQ("abc", parse.value());
}

// User-info without a password: "a@" -> username "a", no password.
TEST(TDE, url_normal) {
  auto ipt = lexy::string_input(std::string_view("a@"));
  auto parse = lexy::parse(ipt, lexy_ext::report_error);
  ASSERT_TRUE(parse.has_value());
  ASSERT_EQ("a", parse.value().username_);
  ASSERT_FALSE(parse.value().password_.has_value());
}

// A bare host name parses back unchanged.
TEST(TDE, url_host) {
  auto ipt = lexy::string_input(std::string_view("www.qq.com"));
  auto parse = lexy::parse(ipt, lexy_ext::report_error);
  ASSERT_TRUE(parse.has_value());
  ASSERT_EQ("www.qq.com", parse.value());
}

// ParseUrl separates the host from the parameter string after "/?".
TEST(TDE, url) {
  auto url = ParseUrl("www.qq.com/?a=b&&c=d");
  ASSERT_EQ(url.host_, "www.qq.com");
  ASSERT_TRUE(url.param_.has_value());
  ASSERT_EQ("a=b&&c=d", url.param_.value());
}

// Malformed input makes ParseUrl throw.
TEST(TDE, bad_url) {
  ASSERT_ANY_THROW([] { ParseUrl("blablah!@@"); }());
}

} // namespace tde::details::url_parser::rules
48 |
--------------------------------------------------------------------------------
/examples/nvt_dataloader/README.md:
--------------------------------------------------------------------------------
1 | # Running torchrec using NVTabular DataLoader
2 |
3 | First, run NVTabular preprocessing to convert the Criteo TSV files to Parquet and perform offline preprocessing.
4 |
5 | Please follow the installation instructions in the [README](https://github.com/pytorch/torchrec/tree/main/torchrec/datasets/scripts/nvt) of torchrec/torchrec/datasets/scripts/nvt.
6 |
7 | Afterward, to run the model across 8 GPUs, use the following command:
8 |
9 | ```
10 | torchx run -s local_cwd dist.ddp -j 1x8 --script train_torchrec.py -- --num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 --over_arch_layer_sizes 1024,1024,512,256,1 --dense_arch_layer_sizes 512,256,128 --embedding_dim 128 --binary_path /criteo_binary/split/ --learning_rate 1.0 --validation_freq_within_epoch 1000000 --throughput_check_freq_within_epoch 1000000 --batch_size 256
11 | ```
12 |
13 | To use Adagrad as the optimizer, add the following flag:
14 |
15 | ```
16 | --adagrad
17 | ```
18 |
19 | # Test on A100s
20 |
21 | ## Preliminary Training Results
22 |
23 | **Setup:**
24 | * Dataset: Criteo 1TB Click Logs dataset
25 | * CUDA 11.1, NCCL 2.10.3.
26 | * AWS p4d.24xlarge instances, each with 8 NVIDIA A100 40GB GPUs.
27 |
28 | **Results**
29 |
30 | Reproducing MLPerf v1 settings:
31 | 1. Embedding per features + model architecture
32 | 2. Learning Rate fixed at 1.0 with SGD
33 | 3. Dataset setup:
34 | - No frequency thresholding
35 | 4. Reaches a validation AUC above 0.8025 (0.8027645945549011 with the script above)
36 | 5. Global batch size 2048
37 |
--------------------------------------------------------------------------------
/examples/retrieval/data/dataloader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | from torch.utils.data import DataLoader
11 | from torchrec.datasets.movielens import DEFAULT_RATINGS_COLUMN_NAMES
12 | from torchrec.datasets.random import RandomRecDataset
13 |
14 |
def get_dataloader(
    batch_size: int, num_embeddings: int, pin_memory: bool = False, num_workers: int = 0
) -> DataLoader:
    """
    Gets a random dataloader for the two-tower model. Each batch has a two-feature
    KJT as sparse_features, empty dense_features and binary labels.

    Args:
        batch_size (int): number of samples per batch, generated by the dataset.
        num_embeddings (int): hash size of the two embedding tables (passed to
            the dataset as ``hash_size``).
        pin_memory (bool): whether the DataLoader copies batches into pinned
            (page-locked) host memory, which speeds up host-to-GPU copies.
        num_workers (int): number of DataLoader worker processes.

    Returns:
        DataLoader: PyTorch dataloader for the specified options.
    """
    # The first two movielens rating columns serve as the two tower features.
    two_tower_column_names = DEFAULT_RATINGS_COLUMN_NAMES[:2]

    return DataLoader(
        RandomRecDataset(
            keys=two_tower_column_names,
            batch_size=batch_size,
            hash_size=num_embeddings,
            ids_per_feature=1,
            num_dense=0,
        ),
        # The dataset is given batch_size itself, so DataLoader's automatic
        # batching is disabled.
        batch_size=None,
        batch_sampler=None,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )
47 |
--------------------------------------------------------------------------------
/torchrec/datasets/scripts/tests/test_npy_preproc_criteo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import os
11 | import tempfile
12 | import unittest
13 |
14 | import numpy as np
15 | from torchrec.datasets.criteo import CAT_FEATURE_COUNT, INT_FEATURE_COUNT
16 | from torchrec.datasets.scripts.npy_preproc_criteo import main
17 | from torchrec.datasets.test_utils.criteo_test_utils import CriteoTest
18 |
19 |
class MainTest(unittest.TestCase):
    """End-to-end check of the Criteo TSV -> .npy preprocessing script."""

    def test_main(self) -> None:
        num_rows = 10
        day = "day_0"
        with CriteoTest._create_dataset_tsv(
            num_rows=num_rows,
            filename=day,
        ) as in_file_path, tempfile.TemporaryDirectory() as output_dir:
            input_dir = os.path.dirname(in_file_path)
            main(["--input_dir", input_dir, "--output_dir", output_dir])

            def load(suffix: str) -> np.ndarray:
                return np.load(os.path.join(output_dir, f"{day}_{suffix}.npy"))

            # Load every output before asserting on any of them.
            loaded = {part: load(part) for part in ("dense", "sparse", "labels")}

            expected_shapes = {
                "dense": (num_rows, INT_FEATURE_COUNT),
                "sparse": (num_rows, CAT_FEATURE_COUNT),
                "labels": (num_rows, 1),
            }
            for part, shape in expected_shapes.items():
                self.assertEqual(loaded[part].shape, shape)
44 |
--------------------------------------------------------------------------------
/torchrec/inference/inference_legacy/src/Validation.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include "torchrec/inference/Validation.h"
10 | #include "ATen/Functions.h"
11 |
12 | namespace torchrec {
13 |
14 | bool validateSparseFeatures(
15 | at::Tensor& values,
16 | at::Tensor& lengths,
17 | std::optional maybeWeights) {
18 | auto flatLengths = lengths.view(-1);
19 |
20 | // validate sum of lengths equals number of values/weights
21 | auto lengthsTotal = at::sum(flatLengths).item();
22 | if (lengthsTotal != values.size(0)) {
23 | return false;
24 | }
25 | if (maybeWeights.has_value() && lengthsTotal != maybeWeights->size(0)) {
26 | return false;
27 | }
28 |
29 | // Validate no negative values in lengths.
30 | // Use faster path if contiguous.
31 | if (flatLengths.is_contiguous()) {
32 | int* ptr = (int*)flatLengths.data_ptr();
33 | for (int i = 0; i < flatLengths.numel(); ++i) {
34 | if (*ptr < 0) {
35 | return false;
36 | }
37 | ptr++;
38 | }
39 | } else {
40 | // accessor does boundary check (slower)
41 | auto acc = flatLengths.accessor();
42 | for (int i = 0; i < acc.size(0); i++) {
43 | if (acc[i] < 0) {
44 | return false;
45 | }
46 | }
47 | }
48 |
49 | return true;
50 | }
51 |
52 | bool validateDenseFeatures(at::Tensor& values, size_t batchSize) {
53 | return values.size(0) % batchSize == 0;
54 | }
55 |
56 | } // namespace torchrec
57 |
--------------------------------------------------------------------------------
/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include
10 | #include
11 |
namespace torchrec {
// Benchmarks MixedLFULRUStrategy::update(). A pool of num_ext_values slots is
// seeded once. Each timed iteration then advances the strategy's time by one
// and updates num_elems_per_iter randomly chosen slots.
// NOTE(review): template arguments (e.g. on std::vector and
// std::uniform_int_distribution) appear stripped from this copy; restore them
// from upstream before compiling.
void BM_MixedLFULRUStrategy(benchmark::State& state) {
  size_t num_ext_values = state.range(0);
  std::vector ext_values(num_ext_values);

  // Seed every slot with an initial update that has no previous value.
  MixedLFULRUStrategy strategy;
  for (auto& v : ext_values) {
    v = strategy.update(0, 0, std::nullopt);
  }

  size_t num_elems = state.range(1);
  std::default_random_engine engine((std::random_device())());
  size_t time = 0;
  for (auto _ : state) {
    // Draw the slot indices outside the timed region.
    state.PauseTiming();
    std::vector offsets;
    offsets.reserve(num_elems);
    for (size_t i = 0; i < num_elems; ++i) {
      // NOTE(review): indices come from [0, num_elems - 1] inclusive, so only
      // the first num_elems of the num_ext_values slots are ever touched. If
      // the intent was to sample the whole pool, the upper bound should be
      // num_ext_values - 1. The distribution could also be built once,
      // outside this loop. Confirm intent.
      std::uniform_int_distribution dist(0, num_elems - 1);
      offsets.emplace_back(dist(engine));
    }
    state.ResumeTiming();

    // Advance the strategy's time, then update each sampled slot in place.
    ++time;
    strategy.update_time(time);
    for (auto& v : offsets) {
      ext_values[v] = strategy.update(0, 0, ext_values[v]);
    }
  }
}

BENCHMARK(BM_MixedLFULRUStrategy)
    ->ArgNames({"num_ext_values", "num_elems_per_iter"})
    ->Args({30000000, 1024 * 1024})
    ->Args({300000000, 1024 * 1024})
    ->Unit(benchmark::kMillisecond)
    ->Iterations(100);

} // namespace torchrec
51 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/move_only_function.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include
4 | #include
5 |
namespace tde::details {

// NOTE(review): the template parameter lists in this copy appear stripped
// (e.g. `template` with no `<...>`, `std::unique_ptr f_`, `std::forward(args)`).
// Presumably the second class is a specialization for a function type
// R(Args...) — restore from upstream before compiling.

// Primary template; only the specialization below is defined.
template
class MoveOnlyFunction;

// Type-erased callable wrapper that can hold move-only callables. The callable
// lives in a heap-allocated Derived holder owned through a unique_ptr to Base.
// The wrapper itself can be moved but not copied.
template
class MoveOnlyFunction {
 public:
  // A default-constructed wrapper is empty (operator bool returns false).
  MoveOnlyFunction() = default;

  // Wraps callable f by moving it into a newly allocated holder.
  template
  /*implicit*/ MoveOnlyFunction(F f) : f_(new Derived(std::move(f))) {}

  // Copying is disabled; only ownership transfer (move) is allowed.
  MoveOnlyFunction(const MoveOnlyFunction&) = delete;
  MoveOnlyFunction(MoveOnlyFunction&& o) noexcept
      : f_(std::move(o.f_)) {}

  MoveOnlyFunction& operator=(const MoveOnlyFunction&) =
      delete;
  MoveOnlyFunction& operator=(
      MoveOnlyFunction&&) noexcept = default;

  // Replaces the held callable with f.
  template
  MoveOnlyFunction& operator=(F f) {
    (*this) = MoveOnlyFunction(std::move(f));
    return *this;
  }

  // Invokes the held callable. Precondition: the wrapper is non-empty;
  // calling an empty wrapper dereferences a null unique_ptr.
  R operator()(Args&&... args) {
    return (*f_)(std::forward(args)...);
  }

  // True when a callable is held.
  // NOTE(review): this conversion is implicit, which permits accidental use in
  // arithmetic/comparisons; consider making it `explicit`.
  /*implicit*/ operator bool() const {
    return f_ != nullptr;
  }

 private:
  // Virtual call interface used for type erasure.
  struct Base {
    virtual ~Base() = default;
    virtual R operator()(Args&&...) = 0;
  };

  // Holder for a concrete callable of type F.
  template
  struct Derived final : public Base {
    explicit Derived(F f) : f_(std::move(f)) {}
    R operator()(Args&&... args) override {
      return f_(std::forward(args)...);
    }
    F f_;
  };
  std::unique_ptr f_;
};

} // namespace tde::details
60 |
--------------------------------------------------------------------------------
/torchrec/modules/activation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | """
11 | Activation Modules
12 | """
13 |
14 | from typing import List, Optional, Union
15 |
16 | import torch
17 | from torch import nn
18 |
19 |
class SwishLayerNorm(nn.Module):
    """
    Swish activation gated by layer normalization: `Y = X * Sigmoid(LayerNorm(X))`.

    Args:
        input_dims (Union[int, List[int], torch.Size]): trailing dimensions to
            normalize over. For an input of shape [batch_size, d1, d2, d3],
            passing input_dims=[d2, d3] normalizes over the last two dimensions.
        device (Optional[torch.device]): device on which the LayerNorm
            parameters are created.

    Example::

        sln = SwishLayerNorm(100)
    """

    def __init__(
        self,
        input_dims: Union[int, List[int], torch.Size],
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        layer_norm = nn.LayerNorm(input_dims, device=device)
        # Kept as a Sequential named `norm` (LayerNorm then Sigmoid) so the
        # parameters appear as `norm.0.weight` / `norm.0.bias`.
        self.norm: torch.nn.modules.Sequential = nn.Sequential(
            layer_norm, nn.Sigmoid()
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): tensor whose trailing dimensions match input_dims.

        Returns:
            torch.Tensor: `input` scaled elementwise by sigmoid(LayerNorm(input)).
        """
        gate = self.norm(input)
        return input * gate
58 |
--------------------------------------------------------------------------------
/examples/ray/compute_world_size.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | #!/usr/bin/env python3
9 |
10 | import os
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.distributed import all_reduce, get_rank, get_world_size, init_process_group
15 |
16 |
def compute_world_size() -> int:
    """
    Dummy script to compute world_size. Meant to test if can run Ray + Pytorch DDP.

    Reads RANK, WORLD_SIZE, MASTER_PORT and MASTER_ADDR from the environment,
    initializes a ``gloo`` process group over TCP, and all-reduces a one-hot
    encoding of this rank. It returns the sum of the reduced tensor, which
    counts the participating ranks.

    Returns:
        int: the world size computed via ``all_reduce``.

    Raises:
        RuntimeError: if any of the required environment variables is unset.
    """
    required_env = ("RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR")
    missing = [name for name in required_env if os.getenv(name) is None]
    if missing:
        # Fail fast with a readable message. Otherwise int(None) raises an
        # opaque TypeError, or a missing MASTER_ADDR yields "tcp://None:<port>".
        raise RuntimeError(
            "compute_world_size requires environment variable(s): "
            + ", ".join(missing)
        )

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    master_port = int(os.environ["MASTER_PORT"])
    master_addr = os.environ["MASTER_ADDR"]
    backend = "gloo"

    print(f"initializing `{backend}` process group")
    init_process_group(  # pyre-ignore[16]
        backend=backend,
        init_method=f"tcp://{master_addr}:{master_port}",
        rank=rank,
        world_size=world_size,
    )
    print("successfully initialized process group")

    # Re-read rank/world_size from the initialized group.
    rank = get_rank()  # pyre-ignore[16]
    world_size = get_world_size()  # pyre-ignore[16]

    # Each rank contributes a one-hot vector at its own index. all_reduce
    # (default op: SUM) adds them, so the total counts the contributing ranks.
    t = F.one_hot(torch.tensor(rank), num_classes=world_size)
    all_reduce(t)  # pyre-ignore[16]
    computed_world_size = int(torch.sum(t).item())
    print(
        f"rank: {rank}, actual world_size: {world_size}, computed world_size: {computed_world_size}"
    )
    return computed_world_size
44 |
45 |
def main() -> None:
    """Script entry point: run the world-size sanity check once."""
    # compute_world_size prints its own result; the returned value is unused here.
    compute_world_size()


if __name__ == "__main__":
    main()
52 |
--------------------------------------------------------------------------------
/examples/ray/README.md:
--------------------------------------------------------------------------------
1 | # Running torchrec with torchx using Ray scheduler on a Ray cluster
2 |
3 | ```
4 | pip install --pre torchrec -f https://download.pytorch.org/whl/torchrec/index.html
5 | pip install torchx-nightly
6 | pip install "ray[default]" -qqq
7 | ```
8 |
9 | Run torchx, passing the dashboard address and the path to your script:
10 | ```
11 | torchx run -s ray -cfg dashboard_address=localhost:6379,working_dir=~/repos/torchrec/examples/ray,requirements=./requirements.txt dist.ddp -j 1x2 --script ~/repos/torchrec/examples/ray/train_torchrec.py
12 | ```
13 |
14 | Or run locally
15 | ```
16 | torchx run -s ray -cfg working_dir=~/repos/torchrec/examples/ray,requirements=./requirements.txt dist.ddp -j 1x2 --script ~/repos/torchrec/examples/ray/train_torchrec.py
17 | ```
18 |
19 | To run with GPUs, pass the --gpu flag to the dist.ddp component.
20 | For other dist.ddp options, see https://pytorch.org/torchx/latest/components/distributed.html
21 | ```
22 | torchx run -s ray -cfg working_dir=~/repos/torchrec/examples/ray,requirements=./requirements.txt dist.ddp -j 1x2 --gpu 2 --script ~/repos/torchrec/examples/ray/train_torchrec.py
23 | ```
24 |
25 | To run without the Ray scheduler (torchx only):
26 | For the available settings, see https://pytorch.org/torchx/latest/cli.html?highlight=torchx%20run
27 | ```
28 | torchx run -s local_cwd dist.ddp -j 1x2 --script ~/repos/torchrec/examples/ray/train_torchrec.py
29 | ```
30 |
31 | A job ID looks like ray://torchx/172.31.16.248:6379-raysubmit_ntquG1dDV6CtFUC5
32 | Replace the job ID in the commands below with your own.
33 |
34 |
35 | Get the job status
36 | (e.g. PENDING, FAILED, INTERRUPTED):
37 | ```
38 | torchx status ray://torchx/172.31.16.248:6379-raysubmit_ntquG1dDV6CtFUC5
39 | ```
40 |
41 | Get logs
42 | ```
43 | torchx log ray://torchx/172.31.16.248:6379-raysubmit_ntquG1dDV6CtFUC5/worker/0
44 | ```
45 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to torchrec
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Pull Requests
6 | We actively welcome your pull requests.
7 |
8 | 1. Fork the repo and create your branch from `main`.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If you've changed APIs, update the documentation.
11 | 4. Ensure the test suite passes.
12 | 5. Make sure your code lints. You can use [lintrunner](https://github.com/pytorch/pytorch/wiki/lintrunner) to do so.
13 | 1. To set up:
14 | ```
15 | pip install lintrunner
16 | lintrunner init
17 | ```
18 | 2. To lint your local changes:
19 | ```
20 | lintrunner
21 | ```
22 | 3. To format locally changed files:
23 | ```
24 | lintrunner f
25 | ```
26 | 4. To lint all files:
27 | ```
28 | lintrunner --all-files
29 | ```
30 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
31 |
32 | ## Contributor License Agreement ("CLA")
33 | In order to accept your pull request, we need you to submit a CLA. You only need
34 | to do this once to work on any of Facebook's open source projects.
35 |
36 | Complete your CLA here: [https://code.facebook.com/cla](https://code.facebook.com/cla)
37 |
38 | ## Issues
39 | We use GitHub issues to track public bugs. Please ensure your description is
40 | clear and has sufficient instructions to be able to reproduce the issue.
41 |
42 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
43 | disclosure of security bugs. In those cases, please go through the process
44 | outlined on that page and do not file a public issue.
45 |
46 | ## License
47 | By contributing to torchrec, you agree that your contributions will be licensed
48 | under the LICENSE file in the root directory of this source tree.
49 |
--------------------------------------------------------------------------------
/torchrec/inference/include/torchrec/inference/Exception.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 | #include
11 |
namespace torchrec {

// Root of the torchrec inference exception hierarchy.
//
// The fblearner/sigrid predictor defines a distinct error code for each kind
// of failure (see fblearner/predictor/if/prediction_service.thrift). Giving
// each kind its own C++ type lets the predictor detect which one was thrown
// and report the matching error code.
class TorchrecException : public std::runtime_error {
 public:
  explicit TorchrecException(const std::string& error)
      : std::runtime_error(error) {}
};

// Reported as PredictionExceptionCode::GPU_BATCHING_QUEUE_TIMEOUT.
class GPUOverloadException : public TorchrecException {
 public:
  // Inherits the explicit (const std::string&) constructor.
  using TorchrecException::TorchrecException;
};

// Reported as PredictionExceptionCode::GPU_EXECUTOR_QUEUE_TIMEOUT.
class GPUExecutorOverloadException : public TorchrecException {
 public:
  // Inherits the explicit (const std::string&) constructor.
  using TorchrecException::TorchrecException;
};

// Reported as PredictorUserErrorCode::TORCH_DEPLOY_ERROR.
class TorchDeployException : public TorchrecException {
 public:
  // Inherits the explicit (const std::string&) constructor.
  using TorchrecException::TorchrecException;
};
} // namespace torchrec
50 |
--------------------------------------------------------------------------------
/contrib/dynamic_embedding/src/tde/details/id_transformer_variant_impl.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
namespace tde::details {

// NOTE(review): template parameter lists and template arguments appear to be
// missing in this copy (e.g. a bare `template`, `tcb::span global_ids`,
// `std::forward(update)`); confirm against the upstream source before compiling.

// Maps `global_ids` to `cache_ids` using whichever transformer alternative
// `var_` currently holds. The LXU strategy supplies the `update` callable
// (built in VisitUpdator below, forwarding to the strategy's Update), and
// `fetch` is passed through to the transformer. Returns the transformer's
// bool result.
template
inline bool IDTransformer::Transform(
    tcb::span global_ids,
    tcb::span cache_ids,
    Fetch fetch) {
  return strategy_.VisitUpdator([&](auto&& update) -> bool {
    return std::visit(
        [&](auto&& transformer) -> bool {
          return transformer.Transform(
              global_ids,
              cache_ids,
              std::forward(update),
              // `fetch` is moved here. This is safe because VisitUpdator calls
              // its visitor once and std::visit calls this lambda once.
              std::move(fetch));
        },
        var_);
  });
}

// Visits the active alternative of `strategy_` and wraps its Update method in
// a callable taking (record, global_id, cache_id). It then invokes `visit`
// with that callable and returns whatever `visit` returns.
template
inline auto IDTransformer::LXUStrategy::VisitUpdator(Visitor visit)
    -> std::invoke_result_t {
  return std::visit(
      [&](auto& s) {
        // Record type declared by the concrete strategy held in the variant.
        using T = typename std::decay_t::lxu_record_t;
        // Note the argument reordering: the callable takes
        // (record, global_id, cache_id), while s.Update is called with
        // (global_id, cache_id, record).
        auto update =
            [&](std::optional record, int64_t global_id, int64_t cache_id) {
              return s.Update(global_id, cache_id, record);
            };
        return visit(update);
      },
      strategy_);
}

// Forwards `record` to the active strategy's Time() and returns its result.
template
int64_t IDTransformer::LXUStrategy::Time(T record) {
  return std::visit(
      [&](auto& s) -> int64_t { return s.Time(record); }, strategy_);
}

// Forwards `iterator` and `num_to_evict` to the active strategy's Evict() and
// returns its result.
template
inline std::vector IDTransformer::LXUStrategy::Evict(
    Iterator iterator,
    uint64_t num_to_evict) {
  return std::visit(
      // The iterator is moved into the capture; `mutable` allows moving the
      // captured copy again into Evict.
      [&, iterator = std::move(iterator)](auto& s) mutable {
        return s.Evict(std::move(iterator), num_to_evict);
      },
      strategy_);
}

} // namespace tde::details
56 |
--------------------------------------------------------------------------------
/torchrec/csrc/dynamic_embedding/details/bitmap_impl.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 | #include
11 | #include
12 |
13 | namespace torchrec {
14 |
15 | template
16 | inline Bitmap::Bitmap(int64_t num_bits)
17 | : num_total_bits_(num_bits),
18 | num_values_((num_bits + num_bits_per_value - 1) / num_bits_per_value),
19 | values_(new T[num_values_]),
20 | next_free_bit_(0) {
21 | std::fill(values_.get(), values_.get() + num_values_, -1);
22 | }
23 |
24 | template
25 | inline int64_t Bitmap::next_free_bit() {
26 | int64_t result = next_free_bit_;
27 | int64_t offset = result / num_bits_per_value;
28 | T value = values_[offset];
29 | // set the last 1 bit to zero
30 | values_[offset] = value & (value - 1);
31 | while (values_[offset] == 0 && offset < num_values_) {
32 | offset++;
33 | }
34 | value = values_[offset];
35 | if (C10_LIKELY(value)) {
36 | next_free_bit_ = offset * num_bits_per_value + ctz(value);
37 | } else {
38 | next_free_bit_ = num_total_bits_;
39 | }
40 |
41 | return result;
42 | }
43 |
44 | template
45 | inline void Bitmap::free_bit(int64_t offset) {
46 | int64_t mask_offset = offset / num_bits_per_value;
47 | int64_t bit_offset = offset % num_bits_per_value;
48 | values_[mask_offset] |= 1 << bit_offset;
49 | next_free_bit_ = std::min(offset, next_free_bit_);
50 | }
51 | template
52 | inline bool Bitmap::full() const {
53 | return next_free_bit_ >= num_total_bits_;
54 | }
55 |
56 | } // namespace torchrec
57 |
--------------------------------------------------------------------------------
/torchrec/inference/inference_legacy/tests/predict_module_tests.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import unittest
9 | from typing import Dict
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torchrec.inference.modules import PredictModule, quantize_dense
14 |
15 |
class TestModule(nn.Module):
    """Tiny two-layer model (10 -> 1 -> 1) used as the wrapped module in tests."""

    def __init__(self) -> None:
        super().__init__()
        self.linear0 = nn.Linear(10, 1)
        self.linear1 = nn.Linear(1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply ``linear0`` followed by ``linear1``.

        Defining forward makes the module callable; nn.Module's default
        forward raises instead.

        Args:
            x (torch.Tensor): input whose last dimension is 10.

        Returns:
            torch.Tensor: output whose last dimension is 1.
        """
        return self.linear1(self.linear0(x))
21 |
22 |
class TestPredictModule(PredictModule):
    """PredictModule that feeds the batch's tensors to the wrapped module."""

    def predict_forward(
        self, batch: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        # Unpack the dict's values: `*batch` would pass the string keys
        # positionally instead of the tensors.
        # NOTE(review): the wrapped module's output is returned unchanged and
        # may not actually be a Dict[str, Tensor] as annotated.
        return self.predict_module(*batch.values())
28 |
29 |
class PredictModulesTest(unittest.TestCase):
    def test_predict_module(self) -> None:
        """Wrapping a module must expose the same state_dict keys and tensors."""
        module = TestModule()
        predict_module = TestPredictModule(module)

        module_state_dict = module.state_dict()
        predict_module_state_dict = predict_module.state_dict()

        self.assertEqual(module_state_dict.keys(), predict_module_state_dict.keys())

        # Compare tensors by key rather than zipping values, so the check does
        # not rely on both dicts iterating in the same order.
        for key, tensor in module_state_dict.items():
            self.assertTrue(torch.equal(tensor, predict_module_state_dict[key]), key)

    def test_dense_lowering(self) -> None:
        """quantize_dense(..., torch.half) must cast every parameter to half."""
        module = TestModule()
        predict_module = TestPredictModule(module)
        predict_module = quantize_dense(predict_module, torch.half)
        params = list(predict_module.parameters())
        # Guard against a vacuous pass if no parameters were exposed.
        self.assertTrue(params)
        for param in params:
            self.assertEqual(param.dtype, torch.half)
51 |
--------------------------------------------------------------------------------
/torchrec/streamable.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import abc
11 |
12 | import torch
13 |
14 |
class Multistreamable(abc.ABC):
    """
    Interface for objects that may be handed off from one CUDA stream to
    another. Both torch.Tensor and (Keyed)JaggedTensor satisfy it.
    """

    @abc.abstractmethod
    def record_stream(self, stream: torch.Stream) -> None:
        """
        Mark this object as used by ``stream``.

        Semantics mirror ``torch.Tensor.record_stream``:
        https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
        """
        pass
28 |
29 |
class Pipelineable(Multistreamable):
    """
    Interface with two methods: ``to`` moves an input across devices, and
    ``record_stream`` (inherited) marks the streams that operate on the input.

    torch.Tensor implements this interface, so it can be used in many
    applications. Another example is torchrec.(Keyed)JaggedTensor, which we use
    as the input to torchrec.EmbeddingBagCollection, which in turn is often the
    first layer of many models. Some models take compound inputs; those inputs
    should implement this interface too.
    """

    @abc.abstractmethod
    def to(self, device: torch.device, non_blocking: bool) -> "Pipelineable":
        """
        Move this object to ``device``.

        As with https://pytorch.org/docs/stable/generated/torch.Tensor.to.html,
        ``to`` may return ``self`` or a copy of ``self``. Always use the returned
        value, for example ``batch = batch.to(new_device)``.
        """
        ...
49 |
--------------------------------------------------------------------------------
/torchrec/csrc/dynamic_embedding/bind.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include
10 |
11 | #include
12 | #include
13 | #include
14 |
15 | namespace torchrec {
16 | TORCH_LIBRARY(tde, m) {
17 | m.def("register_io", [](const std::string& name) {
18 | IORegistry::Instance().register_plugin(name.c_str());
19 | });
20 |
21 | m.class_("TransformResult")
22 | .def_readonly("success", &TransformResult::success)
23 | .def_readonly("ids_to_fetch", &TransformResult::ids_to_fetch);
24 |
25 | m.class_("IDTransformer")
26 | .def(torch::init