├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── README.md
├── bench
│   ├── dlrm_s_benchmark.sh
│   ├── dlrm_s_criteo_kaggle.sh
│   ├── dlrm_s_criteo_terabyte.sh
│   └── run_and_time.sh
├── cython
│   ├── cython_compile.py
│   └── cython_criteo.py
├── data_loader_terabyte.py
├── data_utils.py
├── dlrm_data_caffe2.py
├── dlrm_data_pytorch.py
├── dlrm_s_caffe2.py
├── dlrm_s_pytorch.py
├── extend_distributed.py
├── input
│   ├── dist_emb_0.log
│   ├── dist_emb_1.log
│   ├── dist_emb_2.log
│   └── trace.log
├── kaggle_dac_loss_accuracy_plots.png
├── mlperf_logger.py
├── optim
│   └── rwsadagrad.py
├── requirements.txt
├── terabyte_0875_loss_accuracy_plots.png
├── test
│   └── dlrm_s_test.sh
├── tools
│   └── visualize.py
├── torchrec_dlrm
│   ├── Dockerfile
│   ├── README.MD
│   ├── __init__.py
│   ├── aws_component.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dlrm_dataloader.py
│   │   └── multi_hot_criteo.py
│   ├── dlrm_main.py
│   ├── lr_scheduler.py
│   ├── md5sums_MLPerf_v2_synthetic_multi_hot_sparse_dataset.txt
│   ├── md5sums_preprocessed_criteo_click_logs_dataset.txt
│   ├── multi_hot.py
│   ├── requirements.txt
│   ├── scripts
│   │   ├── download_Criteo_1TB_Click_Logs_dataset.sh
│   │   ├── materialize_synthetic_multihot_dataset.py
│   │   └── process_Criteo_1TB_Click_Logs_dataset.sh
│   └── tests
│       └── test_dlrm_main.py
└── tricks
    ├── md_embedding_bag.py
    └── qr_embedding_bag.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
4 | Please read the [full text](https://code.fb.com/codeofconduct/)
5 | so that you can understand what actions will and will not be tolerated.
6 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to DLRM
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Pull Requests
6 | We actively welcome your pull requests.
7 |
8 | 1. Fork the repo and create your branch from `main`.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If you've changed APIs, update the documentation.
11 | 4. Ensure the test suite passes.
12 | 5. Make sure your code lints.
13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14 |
15 | ## Contributor License Agreement ("CLA")
16 | In order to accept your pull request, we need you to submit a CLA. You only need
17 | to do this once to work on any of Facebook's open source projects.
18 |
19 | Complete your CLA here: <https://code.facebook.com/cla>
20 |
21 | ## Issues
22 | We use GitHub issues to track public bugs. Please ensure your description is
23 | clear and has sufficient instructions to be able to reproduce the issue.
24 |
25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26 | disclosure of security bugs. In those cases, please go through the process
27 | outlined on that page and do not file a public issue.
28 |
29 | ## Coding Style
30 | * 4 spaces for indentation rather than tabs
31 | * 80 character line length
32 | * in general, please maintain a consistent style with the rest of the code
33 |
34 | ## License
35 | By contributing to DLRM, you agree that your contributions will be licensed
36 | under the LICENSE file in the root directory of this source tree.
37 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | ARG FROM_IMAGE_NAME=pytorch/pytorch:1.3-cuda10.1-cudnn7-runtime
7 | FROM ${FROM_IMAGE_NAME}
8 |
9 | ADD requirements.txt .
10 | RUN pip install -r requirements.txt
11 |
12 | RUN pip install torch==1.3.1
13 |
14 | WORKDIR /code
15 | ADD . .
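# Usage sketch (the image tag "dlrm" is illustrative, not part of the repo; the
# sample command is the tiny-model run from the README):
#   docker build -t dlrm .
#   docker run -it dlrm python dlrm_s_pytorch.py --mini-batch-size=2 --data-size=6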
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Facebook, Inc. and its affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Deep Learning Recommendation Model for Personalization and Recommendation Systems:
2 | =================================================================================
3 | *Copyright (c) Facebook, Inc. and its affiliates.*
4 |
5 | Description:
6 | ------------
7 | An implementation of a deep learning recommendation model (DLRM).
8 | The model input consists of dense and sparse features. The former is a vector
9 | of floating point values. The latter is a list of sparse indices into
10 | embedding tables, which consist of vectors of floating point values.
11 | The selected vectors are passed to MLP networks (denoted by triangles);
12 | in some cases the vectors are interacted through operators (Ops).
13 | ```
14 | output:
15 |                     probability of a click
16 | model:                        |
17 |                              /\
18 |                             /__\
19 |                               |
20 |       _____________________> Op  <___________________
21 |     /                         |                      \
22 |    /\                        /\                      /\
23 |   /__\                      /__\           ...      /__\
24 |    |                          |                       |
25 |    |                         Op                      Op
26 |    |                    ____/__\_____           ____/__\____
27 |    |                   |_Emb_|____|__|    ...  |_Emb_|__|___|
28 | input:
29 | [ dense features ]     [sparse indices] , ..., [sparse indices]
30 | ```
31 | A more precise definition of the model layers:
32 | 1) fully connected layers of an MLP
33 |
34 |     z = f(y)
35 |
36 |     y = Wx + b
37 |
38 | 2) embedding lookup (for a list of sparse indices p=[p1,...,pk])
39 |
40 |     z = Op(e1,...,ek)
41 |
42 |     obtain vectors e1=E[:,p1], ..., ek=E[:,pk]
43 |
44 | 3) Operator Op can be one of the following
45 |
46 |     Sum(e1,...,ek) = e1 + ... + ek
47 |
48 |     Dot(e1,...,ek) = [e1'e1, ..., e1'ek, ..., ek'e1, ..., ek'ek]
49 |
50 |     Cat(e1,...,ek) = [e1', ..., ek']'
51 |
52 | where ' denotes the transpose operation
53 |
54 | See our blog post to learn more about DLRM: [https://ai.facebook.com/blog/dlrm-an-advanced-open-source-deep-learning-recommendation-model/](https://ai.facebook.com/blog/dlrm-an-advanced-open-source-deep-learning-recommendation-model/).
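To make the layer definitions above concrete, here is a minimal PyTorch sketch of the embedding lookup and the three operators. This is a toy illustration only; the table sizes, batch, and index values are made up, and `dlrm_s_pytorch.py` remains the reference implementation:

```python
import torch

# Toy sizes (assumed for illustration): k=3 embedding tables,
# each holding m=4 vectors of dimension n=2.
k, m, n = 3, 4, 2
emb_l = [torch.nn.EmbeddingBag(m, n, mode="sum") for _ in range(k)]

# 2) embedding lookup: each table is queried with a bag of sparse indices;
#    mode="sum" already applies the Sum operator within each bag.
p = [torch.tensor([[1, 2]]), torch.tensor([[0, 3]]), torch.tensor([[2]])]
e = [emb(idx) for emb, idx in zip(emb_l, p)]  # each e[i] has shape [batch, n]

# 3) the three choices of operator Op over the looked-up vectors
z_sum = e[0] + e[1] + e[2]               # Sum(e1,...,ek)
z_cat = torch.cat(e, dim=1)              # Cat(e1,...,ek), shape [batch, k*n]
T = torch.stack(e, dim=1)                # shape [batch, k, n]
z_dot = torch.bmm(T, T.transpose(1, 2))  # Dot(e1,...,ek): all pairwise ei'ej
z_dot = z_dot.flatten(start_dim=1)       # shape [batch, k*k]
```

In the full model, the interaction output is combined with the bottom-MLP output and fed to the top MLP, whose final sigmoid unit produces the click probability (the actual code also keeps only the distinct pairs of the interaction matrix).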
55 |
56 | Cite [Work](https://arxiv.org/abs/1906.00091):
57 | ```
58 | @article{DLRM19,
59 |   author = {Maxim Naumov and Dheevatsa Mudigere and Hao{-}Jun Michael Shi and Jianyu Huang and Narayanan Sundaraman and Jongsoo Park and Xiaodong Wang and Udit Gupta and Carole{-}Jean Wu and Alisson G. Azzolini and Dmytro Dzhulgakov and Andrey Mallevich and Ilia Cherniavskii and Yinghai Lu and Raghuraman Krishnamoorthi and Ansha Yu and Volodymyr Kondratenko and Stephanie Pereira and Xianjie Chen and Wenlin Chen and Vijay Rao and Bill Jia and Liang Xiong and Misha Smelyanskiy},
60 |   title = {Deep Learning Recommendation Model for Personalization and Recommendation Systems},
61 |   journal = {CoRR},
62 |   volume = {abs/1906.00091},
63 |   year = {2019},
64 |   url = {https://arxiv.org/abs/1906.00091},
65 | }
66 | ```
67 |
68 | Related Work:
69 |
70 | On the [system architecture implications](https://arxiv.org/abs/1906.03109), with DLRM as one of the benchmarks,
71 | ```
72 | @article{ArchImpl19,
73 |   author = {Udit Gupta and Xiaodong Wang and Maxim Naumov and Carole{-}Jean Wu and Brandon Reagen and David Brooks and Bradford Cottel and Kim M. Hazelwood and Bill Jia and Hsien{-}Hsin S. Lee and Andrey Malevich and Dheevatsa Mudigere and Mikhail Smelyanskiy and Liang Xiong and Xuan Zhang},
74 |   title = {The Architectural Implications of Facebook's DNN-based Personalized Recommendation},
75 |   journal = {CoRR},
76 |   volume = {abs/1906.03109},
77 |   year = {2019},
78 |   url = {https://arxiv.org/abs/1906.03109},
79 | }
80 | ```
81 |
82 | On the [embedding compression techniques (for number of vectors)](https://arxiv.org/abs/1909.02107), with DLRM as one of the benchmarks,
83 | ```
84 | @article{QuoRemTrick19,
85 |   author = {Hao{-}Jun Michael Shi and Dheevatsa Mudigere and Maxim Naumov and Jiyan Yang},
86 |   title = {Compositional Embeddings Using Complementary Partitions for Memory-Efficient Recommendation Systems},
87 |   journal = {CoRR},
88 |   volume = {abs/1909.02107},
89 |   year = {2019},
90 |   url = {https://arxiv.org/abs/1909.02107},
91 | }
92 | ```
93 |
94 | On the [embedding compression techniques (for dimension of vectors)](https://arxiv.org/abs/1909.11810), with DLRM as one of the benchmarks,
95 | ```
96 | @article{MixDimTrick19,
97 |   author = {Antonio Ginart and Maxim Naumov and Dheevatsa Mudigere and Jiyan Yang and James Zou},
98 |   title = {Mixed Dimension Embeddings with Application to Memory-Efficient Recommendation Systems},
99 |   journal = {CoRR},
100 |   volume = {abs/1909.11810},
101 |   year = {2019},
102 |   url = {https://arxiv.org/abs/1909.11810},
103 | }
104 | ```
105 |
106 | Implementation
107 | --------------
108 | **DLRM PyTorch**. Implementation of DLRM in the PyTorch framework:
109 |
110 |     dlrm_s_pytorch.py
111 |
112 | **DLRM Caffe2**. Implementation of DLRM in the Caffe2 framework:
113 |
114 |     dlrm_s_caffe2.py
115 |
116 | **DLRM Data**. Implementation of DLRM data generation and loading:
117 |
118 |     dlrm_data_pytorch.py, dlrm_data_caffe2.py, data_utils.py
119 |
120 | **DLRM Tests**. Implementation of DLRM tests in ./test
121 |
122 |     dlrm_s_test.sh
123 |
124 | **DLRM Benchmarks**.
Implementation of DLRM benchmarks in ./bench 125 | 126 | dlrm_s_criteo_kaggle.sh, dlrm_s_criteo_terabyte.sh, dlrm_s_benchmark.sh 127 | 128 | Related Work: 129 | 130 | On the [Glow framework](https://github.com/pytorch/glow) implementation 131 | ``` 132 | https://github.com/pytorch/glow/blob/master/tests/unittests/RecommendationSystemTest.cpp 133 | ``` 134 | On the [FlexFlow framework](https://github.com/flexflow/FlexFlow) distributed implementation with Legion backend 135 | ``` 136 | https://github.com/flexflow/FlexFlow/blob/master/examples/cpp/DLRM/dlrm.cc 137 | ``` 138 | 139 | How to run dlrm code? 140 | -------------------- 141 | 1) A sample run of the code, with a tiny model is shown below 142 | ``` 143 | $ python dlrm_s_pytorch.py --mini-batch-size=2 --data-size=6 144 | time/loss/accuracy (if enabled): 145 | Finished training it 1/3 of epoch 0, -1.00 ms/it, loss 0.451893, accuracy 0.000% 146 | Finished training it 2/3 of epoch 0, -1.00 ms/it, loss 0.402002, accuracy 0.000% 147 | Finished training it 3/3 of epoch 0, -1.00 ms/it, loss 0.275460, accuracy 0.000% 148 | ``` 149 | 2) A sample run of the code, with a tiny model in debug mode 150 | ``` 151 | $ python dlrm_s_pytorch.py --mini-batch-size=2 --data-size=6 --debug-mode 152 | model arch: 153 | mlp top arch 3 layers, with input to output dimensions: 154 | [8 4 2 1] 155 | # of interactions 156 | 8 157 | mlp bot arch 2 layers, with input to output dimensions: 158 | [4 3 2] 159 | # of features (sparse and dense) 160 | 4 161 | dense feature size 162 | 4 163 | sparse feature size 164 | 2 165 | # of embeddings (= # of sparse features) 3, with dimensions 2x: 166 | [4 3 2] 167 | data (inputs and targets): 168 | mini-batch: 0 169 | [[0.69647 0.28614 0.22685 0.55131] 170 | [0.71947 0.42311 0.98076 0.68483]] 171 | [[[1], [0, 1]], [[0], [1]], [[1], [0]]] 172 | [[0.55679] 173 | [0.15896]] 174 | mini-batch: 1 175 | [[0.36179 0.22826 0.29371 0.63098] 176 | [0.0921 0.4337 0.43086 0.49369]] 177 | [[[1], [0, 2, 3]], [[1], [1, 2]], [[1], [1]]] 178 | [[0.15307] 179 | [0.69553]] 180 | mini-batch: 2 181 | [[0.60306 0.54507 0.34276 0.30412] 182 | [0.41702 0.6813 0.87546 0.51042]] 183 | [[[2], [0, 1, 2]], [[1], [2]], [[1], [1]]] 184 | [[0.31877] 185 | [0.69197]] 186 | initial parameters (weights and bias): 187 | [[ 0.05438 -0.11105] 188 | [ 0.42513 0.34167] 189 | [-0.1426 -0.45641] 190 | [-0.19523 -0.10181]] 191 | [[ 0.23667 0.57199] 192 | [-0.16638 0.30316] 193 | [ 0.10759 0.22136]] 194 | [[-0.49338 -0.14301] 195 | [-0.36649 -0.22139]] 196 | [[0.51313 0.66662 0.10591 0.13089] 197 | [0.32198 0.66156 0.84651 0.55326] 198 | [0.85445 0.38484 0.31679 0.35426]] 199 | [0.17108 0.82911 0.33867] 200 | [[0.55237 0.57855 0.52153] 201 | [0.00269 0.98835 0.90534]] 202 | [0.20764 0.29249] 203 | [[0.52001 0.90191 0.98363 0.25754 0.56436 0.80697 0.39437 0.73107] 204 | [0.16107 0.6007 0.86586 0.98352 0.07937 0.42835 0.20454 0.45064] 205 | [0.54776 0.09333 0.29686 0.92758 0.569 0.45741 0.75353 0.74186] 206 | [0.04858 0.7087 0.83924 0.16594 0.781 0.28654 0.30647 0.66526]] 207 | [0.11139 0.66487 0.88786 0.69631] 208 | [[0.44033 0.43821 0.7651 0.56564] 209 | [0.0849 0.58267 0.81484 0.33707]] 210 | [0.92758 0.75072] 211 | [[0.57406 0.75164]] 212 | [0.07915] 213 | DLRM_Net( 214 | (emb_l): ModuleList( 215 | (0): EmbeddingBag(4, 2, mode=sum) 216 | (1): EmbeddingBag(3, 2, mode=sum) 217 | (2): EmbeddingBag(2, 2, mode=sum) 218 | ) 219 | (bot_l): Sequential( 220 | (0): Linear(in_features=4, out_features=3, bias=True) 221 | (1): ReLU() 222 | (2): Linear(in_features=3, out_features=2, 
bias=True) 223 | (3): ReLU() 224 | ) 225 | (top_l): Sequential( 226 | (0): Linear(in_features=8, out_features=4, bias=True) 227 | (1): ReLU() 228 | (2): Linear(in_features=4, out_features=2, bias=True) 229 | (3): ReLU() 230 | (4): Linear(in_features=2, out_features=1, bias=True) 231 | (5): Sigmoid() 232 | ) 233 | ) 234 | time/loss/accuracy (if enabled): 235 | Finished training it 1/3 of epoch 0, -1.00 ms/it, loss 0.451893, accuracy 0.000% 236 | Finished training it 2/3 of epoch 0, -1.00 ms/it, loss 0.402002, accuracy 0.000% 237 | Finished training it 3/3 of epoch 0, -1.00 ms/it, loss 0.275460, accuracy 0.000% 238 | updated parameters (weights and bias): 239 | [[ 0.0543 -0.1112 ] 240 | [ 0.42513 0.34167] 241 | [-0.14283 -0.45679] 242 | [-0.19532 -0.10197]] 243 | [[ 0.23667 0.57199] 244 | [-0.1666 0.30285] 245 | [ 0.10751 0.22124]] 246 | [[-0.49338 -0.14301] 247 | [-0.36664 -0.22164]] 248 | [[0.51313 0.66663 0.10591 0.1309 ] 249 | [0.32196 0.66154 0.84649 0.55324] 250 | [0.85444 0.38482 0.31677 0.35425]] 251 | [0.17109 0.82907 0.33863] 252 | [[0.55238 0.57857 0.52154] 253 | [0.00265 0.98825 0.90528]] 254 | [0.20764 0.29244] 255 | [[0.51996 0.90184 0.98368 0.25752 0.56436 0.807 0.39437 0.73107] 256 | [0.16096 0.60055 0.86596 0.98348 0.07938 0.42842 0.20453 0.45064] 257 | [0.5476 0.0931 0.29701 0.92752 0.56902 0.45752 0.75351 0.74187] 258 | [0.04849 0.70857 0.83933 0.1659 0.78101 0.2866 0.30646 0.66526]] 259 | [0.11137 0.66482 0.88778 0.69627] 260 | [[0.44029 0.43816 0.76502 0.56561] 261 | [0.08485 0.5826 0.81474 0.33702]] 262 | [0.92754 0.75067] 263 | [[0.57379 0.7514 ]] 264 | [0.07908] 265 | ``` 266 | 267 | Testing 268 | ------- 269 | Testing scripts to confirm functional correctness of the code 270 | ``` 271 | ./test/dlrm_s_test.sh 272 | Running commands ... 273 | python dlrm_s_pytorch.py 274 | python dlrm_s_caffe2.py 275 | Checking results ... 276 | diff test1 (no numeric values in the output = SUCCESS) 277 | diff test2 (no numeric values in the output = SUCCESS) 278 | diff test3 (no numeric values in the output = SUCCESS) 279 | diff test4 (no numeric values in the output = SUCCESS) 280 | ``` 281 | 282 | *NOTE: Testing scripts accept extra arguments which will be passed along to the model, such as --use-gpu* 283 | 284 | Benchmarking 285 | ------------ 286 | 1) Performance benchmarking 287 | ``` 288 | ./bench/dlrm_s_benchmark.sh 289 | ``` 290 | 291 | 2) The code supports interface with the [Criteo Kaggle Display Advertising Challenge Dataset](https://ailab.criteo.com/ressources/). 292 | - Please do the following to prepare the dataset for use with DLRM code: 293 | - First, specify the raw data file (train.txt) as downloaded with --raw-data-file= 294 | - This is then pre-processed (categorize, concat across days...) to allow using with dlrm code 295 | - The processed data is stored as *.npz file in /input/*.npz 296 | - The processed file (*.npz) can be used for subsequent runs with --processed-data-file= 297 | - The model can be trained using the following script 298 | ``` 299 | ./bench/dlrm_s_criteo_kaggle.sh [--test-freq=1024] 300 | ``` 301 | 302 | 303 | 304 | 3) The code supports interface with the [Criteo Terabyte Dataset](https://labs.criteo.com/2013/12/download-terabyte-click-logs/). 
305 | - Please do the following to prepare the dataset for use with the DLRM code:
306 |   - First, download the raw data files day_0.gz, ..., day_23.gz and unzip them
307 |   - Specify the location of the unzipped text files day_0, ..., day_23, using --raw-data-file=<path/day> (the day number will be appended automatically)
308 |   - These are then pre-processed (categorized, concatenated across days, ...) for use with the DLRM code
309 |   - The processed data is stored as a *.npz file in <root_dir>/input/*.npz
310 |   - The processed file (*.npz) can be used for subsequent runs with --processed-data-file=<path/*.npz>
311 | - The model can be trained using the following script
312 | ```
313 | ./bench/dlrm_s_criteo_terabyte.sh ["--test-freq=10240 --memory-map --data-sub-sample-rate=0.875"]
314 | ```
315 | - The corresponding pre-trained model is available under the [CC-BY-NC license](https://creativecommons.org/licenses/by-nc/2.0/) and can be downloaded here:
316 | [dlrm_emb64_subsample0.875_maxindrange10M_pretrained.pt](https://dlrm.s3-us-west-1.amazonaws.com/models/tb0875_10M.pt)
317 |
318 |
319 |
320 | *NOTE: Benchmarking scripts accept extra arguments which will be passed along to the model, such as --num-batches=100 to limit the number of data samples*
321 |
322 | 4) The code supports an interface with the [MLPerf benchmark](https://mlperf.org).
323 | - Please refer to the following training parameters
324 | ```
325 | --mlperf-logging that keeps track of multiple metrics, including area under the curve (AUC)
326 |
327 | --mlperf-acc-threshold that allows early stopping based on the accuracy metric
328 |
329 | --mlperf-auc-threshold that allows early stopping based on the AUC metric
330 |
331 | --mlperf-bin-loader that enables preprocessing of data into a single binary file
332 |
333 | --mlperf-bin-shuffle that controls whether a random shuffle of mini-batches is performed
334 | ```
335 | - The MLPerf training model is completely specified and can be trained using the following script
336 | ```
337 | ./bench/run_and_time.sh [--use-gpu]
338 | ```
339 | - The corresponding pre-trained model is available under the [CC-BY-NC license](https://creativecommons.org/licenses/by-nc/2.0/) and can be downloaded here:
340 | [dlrm_emb128_subsample0.0_maxindrange40M_pretrained.pt](https://dlrm.s3-us-west-1.amazonaws.com/models/tb00_40M.pt)
341 |
342 | 5) The code now supports synchronous distributed training; the gloo/nccl/mpi backends are supported, and launch modes are provided for the [pytorch distributed launcher](https://pytorch.org/docs/stable/distributed.html#launch-utility) and mpirun. For MPI, users need to write their own MPI launching scripts to configure the running hosts. For example, using the pytorch distributed launcher, training can be launched with the following commands:
343 | ```
344 | # for single node 8 gpus and nccl as backend on randomly generated dataset:
345 | python -m torch.distributed.launch --nproc_per_node=8 dlrm_s_pytorch.py --arch-embedding-size="80000-80000-80000-80000-80000-80000-80000-80000" --arch-sparse-feature-size=64 --arch-mlp-bot="128-128-128-128" --arch-mlp-top="512-512-512-256-1" --max-ind-range=40000000
346 | --data-generation=random --loss-function=bce --round-targets=True --learning-rate=1.0 --mini-batch-size=2048 --print-freq=2 --print-time --test-freq=2 --test-mini-batch-size=2048 --memory-map --use-gpu --num-batches=100 --dist-backend=nccl
347 |
348 | # for multiple nodes, users can add the related arguments according to the launcher manual, e.g.:
349 | --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" --master_port=1234
350 | ```
351 |
352 |
353 | Model checkpoint saving/loading
354 | -------------------------------
355 | During training, the model can be saved using --save-model=<path/model.pt>
356 |
357 | The model is saved if there is an improvement in test accuracy (which is checked at --test-freq intervals).
358 |
359 | A previously saved model can be loaded using --load-model=<path/model.pt>
360 |
361 | Once loaded, the model can be used to continue training, with the saved model acting as a checkpoint.
362 | Alternatively, the saved model can be used to evaluate only on the test data set by specifying the --inference-only option.
363 |
364 |
365 | Version
366 | -------
367 | 0.1 : Initial release of the DLRM code
368 |
369 | 1.0 : DLRM with distributed training, CPU support for the row-wise Adagrad optimizer
370 |
371 | Requirements
372 | ------------
373 | pytorch-nightly (*11/10/20*)
374 |
375 | scikit-learn
376 |
377 | numpy
378 |
379 | onnx (*optional*)
380 |
381 | pydot (*optional*)
382 |
383 | torchviz (*optional*)
384 |
385 | mpi (*optional for distributed backend*)
386 |
387 |
388 | License
389 | -------
390 | This source code is licensed under the MIT license found in the
391 | LICENSE file in the root directory of this source tree.
392 |
--------------------------------------------------------------------------------
/bench/dlrm_s_benchmark.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 7 | #check if extra argument is passed to the test 8 | if [[ $# == 1 ]]; then 9 | dlrm_extra_option=$1 10 | else 11 | dlrm_extra_option="" 12 | fi 13 | #echo $dlrm_extra_option 14 | 15 | cpu=1 16 | gpu=1 17 | pt=1 18 | c2=1 19 | 20 | ncores=28 #12 #6 21 | nsockets="0" 22 | 23 | ngpus="1 2 4 8" 24 | 25 | numa_cmd="numactl --physcpubind=0-$((ncores-1)) -m $nsockets" #run on one socket, without HT 26 | dlrm_pt_bin="python dlrm_s_pytorch.py" 27 | dlrm_c2_bin="python dlrm_s_caffe2.py" 28 | 29 | data=random #synthetic 30 | print_freq=100 31 | rand_seed=727 32 | 33 | c2_net="async_scheduling" 34 | 35 | #Model param 36 | mb_size=2048 #1024 #512 #256 37 | nbatches=1000 #500 #100 38 | bot_mlp="512-512-64" 39 | top_mlp="1024-1024-1024-1" 40 | emb_size=64 41 | nindices=100 42 | emb="1000000-1000000-1000000-1000000-1000000-1000000-1000000-1000000" 43 | interaction="dot" 44 | tnworkers=0 45 | tmb_size=16384 46 | 47 | #_args="--mini-batch-size="${mb_size}\ 48 | _args=" --num-batches="${nbatches}\ 49 | " --data-generation="${data}\ 50 | " --arch-mlp-bot="${bot_mlp}\ 51 | " --arch-mlp-top="${top_mlp}\ 52 | " --arch-sparse-feature-size="${emb_size}\ 53 | " --arch-embedding-size="${emb}\ 54 | " --num-indices-per-lookup="${nindices}\ 55 | " --arch-interaction-op="${interaction}\ 56 | " --numpy-rand-seed="${rand_seed}\ 57 | " --print-freq="${print_freq}\ 58 | " --print-time"\ 59 | " --enable-profiling " 60 | 61 | c2_args=" --caffe2-net-type="${c2_net} 62 | 63 | 64 | # CPU Benchmarking 65 | if [ $cpu = 1 ]; then 66 | echo "--------------------------------------------" 67 | echo "CPU Benchmarking - running on $ncores cores" 68 | echo "--------------------------------------------" 69 | if [ $pt = 1 ]; then 70 | outf="model1_CPU_PT_$ncores.log" 71 | outp="dlrm_s_pytorch.prof" 72 | echo "-------------------------------" 73 | echo "Running PT (log file: $outf)" 74 | echo "-------------------------------" 75 | cmd="$numa_cmd $dlrm_pt_bin --mini-batch-size=$mb_size --test-mini-batch-size=$tmb_size --test-num-workers=$tnworkers $_args $dlrm_extra_option > $outf" 76 | echo $cmd 77 | eval $cmd 78 | min=$(grep "iteration" $outf | awk 'BEGIN{best=999999} {if (best > $7) best=$7} END{print best}') 79 | echo "Min time per iteration = $min" 80 | # move profiling file(s) 81 | mv $outp ${outf//".log"/".prof"} 82 | mv ${outp//".prof"/".json"} ${outf//".log"/".json"} 83 | 84 | fi 85 | if [ $c2 = 1 ]; then 86 | outf="model1_CPU_C2_$ncores.log" 87 | outp="dlrm_s_caffe2.prof" 88 | echo "-------------------------------" 89 | echo "Running C2 (log file: $outf)" 90 | echo "-------------------------------" 91 | cmd="$numa_cmd $dlrm_c2_bin --mini-batch-size=$mb_size $_args $c2_args $dlrm_extra_option 1> $outf 2> $outp" 92 | echo $cmd 93 | eval $cmd 94 | min=$(grep "iteration" $outf | awk 'BEGIN{best=999999} {if (best > $7) best=$7} END{print best}') 95 | echo "Min time per iteration = $min" 96 | # move profiling file (collected from stderr above) 97 | mv $outp ${outf//".log"/".prof"} 98 | fi 99 | fi 100 | 101 | # GPU Benchmarking 102 | if [ $gpu = 1 ]; then 103 | echo "--------------------------------------------" 104 | echo "GPU Benchmarking - running on $ngpus GPUs" 105 | echo "--------------------------------------------" 106 | for _ng in $ngpus 107 | do 108 | # weak scaling 109 | # _mb_size=$((mb_size*_ng)) 110 | # strong scaling 111 | _mb_size=$((mb_size*1)) 112 | _gpus=$(seq -s, 0 $((_ng-1))) 113 | cuda_arg="CUDA_VISIBLE_DEVICES=$_gpus" 114 | echo "-------------------" 115 | echo "Using GPUS: "$_gpus 116 | echo 
"-------------------" 117 | if [ $pt = 1 ]; then 118 | outf="model1_GPU_PT_$_ng.log" 119 | outp="dlrm_s_pytorch.prof" 120 | echo "-------------------------------" 121 | echo "Running PT (log file: $outf)" 122 | echo "-------------------------------" 123 | cmd="$cuda_arg $dlrm_pt_bin --mini-batch-size=$_mb_size --test-mini-batch-size=$tmb_size --test-num-workers=$tnworkers $_args --use-gpu $dlrm_extra_option > $outf" 124 | echo $cmd 125 | eval $cmd 126 | min=$(grep "iteration" $outf | awk 'BEGIN{best=999999} {if (best > $7) best=$7} END{print best}') 127 | echo "Min time per iteration = $min" 128 | # move profiling file(s) 129 | mv $outp ${outf//".log"/".prof"} 130 | mv ${outp//".prof"/".json"} ${outf//".log"/".json"} 131 | fi 132 | if [ $c2 = 1 ]; then 133 | outf="model1_GPU_C2_$_ng.log" 134 | outp="dlrm_s_caffe2.prof" 135 | echo "-------------------------------" 136 | echo "Running C2 (log file: $outf)" 137 | echo "-------------------------------" 138 | cmd="$cuda_arg $dlrm_c2_bin --mini-batch-size=$_mb_size $_args $c2_args --use-gpu $dlrm_extra_option 1> $outf 2> $outp" 139 | echo $cmd 140 | eval $cmd 141 | min=$(grep "iteration" $outf | awk 'BEGIN{best=999999} {if (best > $7) best=$7} END{print best}') 142 | echo "Min time per iteration = $min" 143 | # move profiling file (collected from stderr above) 144 | mv $outp ${outf//".log"/".prof"} 145 | fi 146 | done 147 | fi 148 | -------------------------------------------------------------------------------- /bench/dlrm_s_criteo_kaggle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | #WARNING: must have compiled PyTorch and caffe2 8 | 9 | #check if extra argument is passed to the test 10 | if [[ $# == 1 ]]; then 11 | dlrm_extra_option=$1 12 | else 13 | dlrm_extra_option="" 14 | fi 15 | #echo $dlrm_extra_option 16 | 17 | dlrm_pt_bin="python dlrm_s_pytorch.py" 18 | dlrm_c2_bin="python dlrm_s_caffe2.py" 19 | 20 | echo "run pytorch ..." 21 | # WARNING: the following parameters will be set based on the data set 22 | # --arch-embedding-size=... (sparse feature sizes) 23 | # --arch-mlp-bot=... (the input to the first layer of bottom mlp) 24 | $dlrm_pt_bin --arch-sparse-feature-size=16 --arch-mlp-bot="13-512-256-64-16" --arch-mlp-top="512-256-1" --data-generation=dataset --data-set=kaggle --raw-data-file=./input/train.txt --processed-data-file=./input/kaggleAdDisplayChallenge_processed.npz --loss-function=bce --round-targets=True --learning-rate=0.1 --mini-batch-size=128 --print-freq=1024 --print-time --test-mini-batch-size=16384 --test-num-workers=16 $dlrm_extra_option 2>&1 | tee run_kaggle_pt.log 25 | 26 | echo "run caffe2 ..." 27 | # WARNING: the following parameters will be set based on the data set 28 | # --arch-embedding-size=... (sparse feature sizes) 29 | # --arch-mlp-bot=... 
(the input to the first layer of bottom mlp) 30 | $dlrm_c2_bin --arch-sparse-feature-size=16 --arch-mlp-bot="13-512-256-64-16" --arch-mlp-top="512-256-1" --data-generation=dataset --data-set=kaggle --raw-data-file=./input/train.txt --processed-data-file=./input/kaggleAdDisplayChallenge_processed.npz --loss-function=bce --round-targets=True --learning-rate=0.1 --mini-batch-size=128 --print-freq=1024 --print-time $dlrm_extra_option 2>&1 | tee run_kaggle_c2.log 31 | 32 | echo "done" 33 | -------------------------------------------------------------------------------- /bench/dlrm_s_criteo_terabyte.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | #WARNING: must have compiled PyTorch and caffe2 8 | 9 | #check if extra argument is passed to the test 10 | if [[ $# == 1 ]]; then 11 | dlrm_extra_option=$1 12 | else 13 | dlrm_extra_option="" 14 | fi 15 | #echo $dlrm_extra_option 16 | 17 | dlrm_pt_bin="python dlrm_s_pytorch.py" 18 | dlrm_c2_bin="python dlrm_s_caffe2.py" 19 | 20 | echo "run pytorch ..." 21 | # WARNING: the following parameters will be set based on the data set 22 | # --arch-embedding-size=... (sparse feature sizes) 23 | # --arch-mlp-bot=... (the input to the first layer of bottom mlp) 24 | $dlrm_pt_bin --arch-sparse-feature-size=64 --arch-mlp-bot="13-512-256-64" --arch-mlp-top="512-512-256-1" --max-ind-range=10000000 --data-generation=dataset --data-set=terabyte --raw-data-file=./input/day --processed-data-file=./input/terabyte_processed.npz --loss-function=bce --round-targets=True --learning-rate=0.1 --mini-batch-size=2048 --print-freq=1024 --print-time --test-mini-batch-size=16384 --test-num-workers=16 $dlrm_extra_option 2>&1 | tee run_terabyte_pt.log 25 | 26 | echo "run caffe2 ..." 27 | # WARNING: the following parameters will be set based on the data set 28 | # --arch-embedding-size=... (sparse feature sizes) 29 | # --arch-mlp-bot=... (the input to the first layer of bottom mlp) 30 | $dlrm_c2_bin --arch-sparse-feature-size=64 --arch-mlp-bot="13-512-256-64" --arch-mlp-top="512-512-256-1" --max-ind-range=10000000 --data-generation=dataset --data-set=terabyte --raw-data-file=./input/day --processed-data-file=./input/terabyte_processed.npz --loss-function=bce --round-targets=True --learning-rate=0.1 --mini-batch-size=2048 --print-freq=1024 --print-time $dlrm_extra_option 2>&1 | tee run_terabyte_c2.log 31 | 32 | echo "done" 33 | -------------------------------------------------------------------------------- /bench/run_and_time.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
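# Usage sketch (from the README; the optional extra argument is passed through
# to the model as $dlrm_extra_option below):
#   ./bench/run_and_time.sh [--use-gpu]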
6 |
7 | #WARNING: must have compiled PyTorch and caffe2
8 |
9 | #check if extra argument is passed to the test
10 | if [[ $# == 1 ]]; then
11 |   dlrm_extra_option=$1
12 | else
13 |   dlrm_extra_option=""
14 | fi
15 | #echo $dlrm_extra_option
16 |
17 | python dlrm_s_pytorch.py --arch-sparse-feature-size=128 --arch-mlp-bot="13-512-256-128" --arch-mlp-top="1024-1024-512-256-1" --max-ind-range=40000000 --data-generation=dataset --data-set=terabyte --raw-data-file=./input/day --processed-data-file=./input/terabyte_processed.npz --loss-function=bce --round-targets=True --learning-rate=1.0 --mini-batch-size=2048 --print-freq=2048 --print-time --test-freq=102400 --test-mini-batch-size=16384 --test-num-workers=16 --memory-map --mlperf-logging --mlperf-auc-threshold=0.8025 --mlperf-bin-loader --mlperf-bin-shuffle $dlrm_extra_option 2>&1 | tee run_terabyte_mlperf_pt.log
18 |
19 | echo "done"
20 |
--------------------------------------------------------------------------------
/cython/cython_compile.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # Description: compile .so from python code
7 |
8 | from __future__ import absolute_import, division, print_function, unicode_literals
9 |
10 | from distutils.extension import Extension
11 |
12 | from Cython.Build import cythonize
13 |
14 | from setuptools import setup
15 |
16 | ext_modules = [
17 |     Extension(
18 |         "data_utils_cython",
19 |         ["data_utils_cython.pyx"],
20 |         extra_compile_args=["-O3"],
21 |         extra_link_args=["-O3"],
22 |     )
23 | ]
24 |
25 | setup(name="data_utils_cython", ext_modules=cythonize(ext_modules))
26 |
--------------------------------------------------------------------------------
/cython/cython_criteo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # Description: run dataset pre-processing in standalone mode
7 | # WARNING: These steps are required to work with Cython
8 | # 1. Install Cython
9 | # > sudo yum install Cython
10 | # 2. Please copy data_utils.py into data_utils_cython.pyx
11 | # 3. Compile data_utils_cython.pyx to generate the .so
12 | # (it's important to keep the extension .pyx rather than .py
13 | # to ensure that the C/C++ .so, not the .py, is loaded at import time)
14 | # > python cython_compile.py build_ext --inplace
15 | # This should create data_utils_cython.so, which can be loaded below with "import"
16 | # 4. Run standalone dataset preprocessing to generate .npz files
17 | # a. Kaggle
18 | # > python cython_criteo.py --data-set=kaggle --raw-data-file=./input/train.txt
19 | # --processed-data-file=./input/kaggleAdDisplayChallenge_processed.npz
20 | # b.
Terabyte 21 | # > python cython_criteo.py --max-ind-range=10000000 [--memory-map] --data-set=terabyte 22 | # --raw-data-file=./input/day --processed-data-file=./input/terabyte_processed.npz 23 | 24 | from __future__ import absolute_import, division, print_function, unicode_literals 25 | 26 | import data_utils_cython as duc 27 | 28 | if __name__ == "__main__": 29 | ### import packages ### 30 | import argparse 31 | 32 | ### parse arguments ### 33 | parser = argparse.ArgumentParser(description="Preprocess Criteo dataset") 34 | # model related parameters 35 | parser.add_argument("--max-ind-range", type=int, default=-1) 36 | parser.add_argument("--data-sub-sample-rate", type=float, default=0.0) # in [0, 1] 37 | parser.add_argument("--data-randomize", type=str, default="total") # or day or none 38 | parser.add_argument("--memory-map", action="store_true", default=False) 39 | parser.add_argument("--data-set", type=str, default="kaggle") # or terabyte 40 | parser.add_argument("--raw-data-file", type=str, default="") 41 | parser.add_argument("--processed-data-file", type=str, default="") 42 | args = parser.parse_args() 43 | 44 | duc.loadDataset( 45 | args.data_set, 46 | args.max_ind_range, 47 | args.data_sub_sample_rate, 48 | args.data_randomize, 49 | "train", 50 | args.raw_data_file, 51 | args.processed_data_file, 52 | args.memory_map, 53 | ) 54 | -------------------------------------------------------------------------------- /data_loader_terabyte.py: -------------------------------------------------------------------------------- 1 | # @lint-ignore-every LICENSELINT 2 | 3 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
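# Usage sketch for the DataLoader class defined below (illustrative only; the
# file layout "./input/day_*_reordered.npz" plus "./input/day_day_count.npz" is
# the one assumed by this file's _test() helper):
#
#   loader = DataLoader(
#       data_filename="day",
#       data_directory="./input",
#       days=list(range(23)),
#       batch_size=2048,
#       split="train",
#   )
#   for x_int, lS_o, x_cat, y in loader:
#       pass  # dense features, offsets, sparse indices, click labels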
7 | 8 | 9 | from __future__ import absolute_import, division, print_function, unicode_literals 10 | 11 | import argparse 12 | import math 13 | 14 | import os 15 | import time 16 | 17 | import numpy as np 18 | import torch 19 | from torch.utils.data import Dataset 20 | from tqdm import tqdm 21 | 22 | 23 | class DataLoader: 24 | """ 25 | DataLoader dedicated for the Criteo Terabyte Click Logs dataset 26 | """ 27 | 28 | def __init__( 29 | self, 30 | data_filename, 31 | data_directory, 32 | days, 33 | batch_size, 34 | max_ind_range=-1, 35 | split="train", 36 | drop_last_batch=False, 37 | ): 38 | self.data_filename = data_filename 39 | self.data_directory = data_directory 40 | self.days = days 41 | self.batch_size = batch_size 42 | self.max_ind_range = max_ind_range 43 | 44 | total_file = os.path.join(data_directory, data_filename + "_day_count.npz") 45 | with np.load(total_file) as data: 46 | total_per_file = data["total_per_file"][np.array(days)] 47 | 48 | self.length = sum(total_per_file) 49 | if split == "test" or split == "val": 50 | self.length = int(np.ceil(self.length / 2.0)) 51 | self.split = split 52 | self.drop_last_batch = drop_last_batch 53 | 54 | def __iter__(self): 55 | return iter( 56 | _batch_generator( 57 | self.data_filename, 58 | self.data_directory, 59 | self.days, 60 | self.batch_size, 61 | self.split, 62 | self.drop_last_batch, 63 | self.max_ind_range, 64 | ) 65 | ) 66 | 67 | def __len__(self): 68 | if self.drop_last_batch: 69 | return self.length // self.batch_size 70 | else: 71 | return math.ceil(self.length / self.batch_size) 72 | 73 | 74 | def _transform_features( 75 | x_int_batch, x_cat_batch, y_batch, max_ind_range, flag_input_torch_tensor=False 76 | ): 77 | if max_ind_range > 0: 78 | x_cat_batch = x_cat_batch % max_ind_range 79 | 80 | if flag_input_torch_tensor: 81 | x_int_batch = torch.log(x_int_batch.clone().detach().type(torch.float) + 1) 82 | x_cat_batch = x_cat_batch.clone().detach().type(torch.long) 83 | y_batch = y_batch.clone().detach().type(torch.float32).view(-1, 1) 84 | else: 85 | x_int_batch = torch.log(torch.tensor(x_int_batch, dtype=torch.float) + 1) 86 | x_cat_batch = torch.tensor(x_cat_batch, dtype=torch.long) 87 | y_batch = torch.tensor(y_batch, dtype=torch.float32).view(-1, 1) 88 | 89 | batch_size = x_cat_batch.shape[0] 90 | feature_count = x_cat_batch.shape[1] 91 | lS_o = torch.arange(batch_size).reshape(1, -1).repeat(feature_count, 1) 92 | 93 | return x_int_batch, lS_o, x_cat_batch.t(), y_batch.view(-1, 1) 94 | 95 | 96 | def _batch_generator( 97 | data_filename, data_directory, days, batch_size, split, drop_last, max_ind_range 98 | ): 99 | previous_file = None 100 | for day in days: 101 | filepath = os.path.join( 102 | data_directory, data_filename + "_{}_reordered.npz".format(day) 103 | ) 104 | 105 | # print('Loading file: ', filepath) 106 | with np.load(filepath) as data: 107 | x_int = data["X_int"] 108 | x_cat = data["X_cat"] 109 | y = data["y"] 110 | 111 | samples_in_file = y.shape[0] 112 | batch_start_idx = 0 113 | if split == "test" or split == "val": 114 | length = int(np.ceil(samples_in_file / 2.0)) 115 | if split == "test": 116 | samples_in_file = length 117 | elif split == "val": 118 | batch_start_idx = samples_in_file - length 119 | 120 | while batch_start_idx < samples_in_file - batch_size: 121 | missing_samples = batch_size 122 | if previous_file is not None: 123 | missing_samples -= previous_file["y"].shape[0] 124 | 125 | current_slice = slice(batch_start_idx, batch_start_idx + missing_samples) 126 | 127 | x_int_batch = 
x_int[current_slice] 128 | x_cat_batch = x_cat[current_slice] 129 | y_batch = y[current_slice] 130 | 131 | if previous_file is not None: 132 | x_int_batch = np.concatenate( 133 | [previous_file["x_int"], x_int_batch], axis=0 134 | ) 135 | x_cat_batch = np.concatenate( 136 | [previous_file["x_cat"], x_cat_batch], axis=0 137 | ) 138 | y_batch = np.concatenate([previous_file["y"], y_batch], axis=0) 139 | previous_file = None 140 | 141 | if x_int_batch.shape[0] != batch_size: 142 | raise ValueError("should not happen") 143 | 144 | yield _transform_features(x_int_batch, x_cat_batch, y_batch, max_ind_range) 145 | 146 | batch_start_idx += missing_samples 147 | if batch_start_idx != samples_in_file: 148 | current_slice = slice(batch_start_idx, samples_in_file) 149 | if previous_file is not None: 150 | previous_file = { 151 | "x_int": np.concatenate( 152 | [previous_file["x_int"], x_int[current_slice]], axis=0 153 | ), 154 | "x_cat": np.concatenate( 155 | [previous_file["x_cat"], x_cat[current_slice]], axis=0 156 | ), 157 | "y": np.concatenate([previous_file["y"], y[current_slice]], axis=0), 158 | } 159 | else: 160 | previous_file = { 161 | "x_int": x_int[current_slice], 162 | "x_cat": x_cat[current_slice], 163 | "y": y[current_slice], 164 | } 165 | 166 | if not drop_last: 167 | yield _transform_features( 168 | previous_file["x_int"], 169 | previous_file["x_cat"], 170 | previous_file["y"], 171 | max_ind_range, 172 | ) 173 | 174 | 175 | def _test(): 176 | generator = _batch_generator( 177 | data_filename="day", 178 | data_directory="./input", 179 | days=range(23), 180 | split="train", 181 | batch_size=2048, 182 | drop_last=True, 183 | max_ind_range=-1, 184 | ) 185 | t1 = time.time() 186 | for x_int, lS_o, x_cat, y in generator: 187 | t2 = time.time() 188 | time_diff = t2 - t1 189 | t1 = t2 190 | print( 191 | "time {} x_int.shape: {} lS_o.shape: {} x_cat.shape: {} y.shape: {}".format( 192 | time_diff, x_int.shape, lS_o.shape, x_cat.shape, y.shape 193 | ) 194 | ) 195 | 196 | 197 | class CriteoBinDataset(Dataset): 198 | """Binary version of criteo dataset.""" 199 | 200 | def __init__( 201 | self, 202 | data_file, 203 | counts_file, 204 | batch_size=1, 205 | max_ind_range=-1, 206 | bytes_per_feature=4, 207 | ): 208 | # dataset 209 | self.tar_fea = 1 # single target 210 | self.den_fea = 13 # 13 dense features 211 | self.spa_fea = 26 # 26 sparse features 212 | self.tad_fea = self.tar_fea + self.den_fea 213 | self.tot_fea = self.tad_fea + self.spa_fea 214 | 215 | self.batch_size = batch_size 216 | self.max_ind_range = max_ind_range 217 | self.bytes_per_entry = bytes_per_feature * self.tot_fea * batch_size 218 | 219 | self.num_entries = math.ceil(os.path.getsize(data_file) / self.bytes_per_entry) 220 | 221 | print("data file:", data_file, "number of batches:", self.num_entries) 222 | self.file = open(data_file, "rb") 223 | 224 | with np.load(counts_file) as data: 225 | self.counts = data["counts"] 226 | 227 | # hardcoded for now 228 | self.m_den = 13 229 | 230 | def __len__(self): 231 | return self.num_entries 232 | 233 | def __getitem__(self, idx): 234 | self.file.seek(idx * self.bytes_per_entry, 0) 235 | raw_data = self.file.read(self.bytes_per_entry) 236 | array = np.frombuffer(raw_data, dtype=np.int32) 237 | tensor = torch.from_numpy(array).view((-1, self.tot_fea)) 238 | 239 | return _transform_features( 240 | x_int_batch=tensor[:, 1:14], 241 | x_cat_batch=tensor[:, 14:], 242 | y_batch=tensor[:, 0], 243 | max_ind_range=self.max_ind_range, 244 | flag_input_torch_tensor=True, 245 | ) 246 | 247 | def 
__del__(self): 248 | self.file.close() 249 | 250 | 251 | def numpy_to_binary(input_files, output_file_path, split="train"): 252 | """Convert the data to a binary format to be read with CriteoBinDataset.""" 253 | 254 | # WARNING - both categorical and numerical data must fit into int32 for 255 | # the following code to work correctly 256 | 257 | with open(output_file_path, "wb") as output_file: 258 | if split == "train": 259 | for input_file in input_files: 260 | print("Processing file: ", input_file) 261 | 262 | np_data = np.load(input_file) 263 | np_data = np.concatenate( 264 | [np_data["y"].reshape(-1, 1), np_data["X_int"], np_data["X_cat"]], 265 | axis=1, 266 | ) 267 | np_data = np_data.astype(np.int32) 268 | 269 | output_file.write(np_data.tobytes()) 270 | else: 271 | assert len(input_files) == 1 272 | np_data = np.load(input_files[0]) 273 | np_data = np.concatenate( 274 | [np_data["y"].reshape(-1, 1), np_data["X_int"], np_data["X_cat"]], 275 | axis=1, 276 | ) 277 | np_data = np_data.astype(np.int32) 278 | 279 | samples_in_file = np_data.shape[0] 280 | midpoint = int(np.ceil(samples_in_file / 2.0)) 281 | if split == "test": 282 | begin = 0 283 | end = midpoint 284 | elif split == "val": 285 | begin = midpoint 286 | end = samples_in_file 287 | else: 288 | raise ValueError("Unknown split value: ", split) 289 | 290 | output_file.write(np_data[begin:end].tobytes()) 291 | 292 | 293 | def _preprocess(args): 294 | train_files = [ 295 | "{}_{}_reordered.npz".format(args.input_data_prefix, day) 296 | for day in range(0, 23) 297 | ] 298 | 299 | test_valid_file = args.input_data_prefix + "_23_reordered.npz" 300 | 301 | os.makedirs(args.output_directory, exist_ok=True) 302 | for split in ["train", "val", "test"]: 303 | print("Running preprocessing for split =", split) 304 | 305 | output_file = os.path.join(args.output_directory, "{}_data.bin".format(split)) 306 | 307 | input_files = train_files if split == "train" else [test_valid_file] 308 | numpy_to_binary( 309 | input_files=input_files, output_file_path=output_file, split=split 310 | ) 311 | 312 | 313 | def _test_bin(): 314 | parser = argparse.ArgumentParser() 315 | parser.add_argument("--output_directory", required=True) 316 | parser.add_argument("--input_data_prefix", required=True) 317 | parser.add_argument("--split", choices=["train", "test", "val"], required=True) 318 | args = parser.parse_args() 319 | 320 | _preprocess(args) 321 | 322 | binary_data_file = os.path.join( 323 | args.output_directory, "{}_data.bin".format(args.split) 324 | ) 325 | 326 | counts_file = os.path.join(args.output_directory, "day_fea_count.npz") 327 | dataset_binary = CriteoBinDataset( 328 | data_file=binary_data_file, 329 | counts_file=counts_file, 330 | batch_size=2048, 331 | ) 332 | from dlrm_data_pytorch import ( 333 | collate_wrapper_criteo_offset as collate_wrapper_criteo, 334 | CriteoDataset, 335 | ) 336 | 337 | binary_loader = torch.utils.data.DataLoader( 338 | dataset_binary, 339 | batch_size=None, 340 | shuffle=False, 341 | num_workers=0, 342 | collate_fn=None, 343 | pin_memory=False, 344 | drop_last=False, 345 | ) 346 | 347 | original_dataset = CriteoDataset( 348 | dataset="terabyte", 349 | max_ind_range=10 * 1000 * 1000, 350 | sub_sample_rate=1, 351 | randomize=True, 352 | split=args.split, 353 | raw_path=args.input_data_prefix, 354 | pro_data="dummy_string", 355 | memory_map=True, 356 | ) 357 | 358 | original_loader = torch.utils.data.DataLoader( 359 | original_dataset, 360 | batch_size=2048, 361 | shuffle=False, 362 | num_workers=0, 363 | 
collate_fn=collate_wrapper_criteo, 364 | pin_memory=False, 365 | drop_last=False, 366 | ) 367 | 368 | assert len(dataset_binary) == len(original_loader) 369 | for i, (old_batch, new_batch) in tqdm( 370 | enumerate(zip(original_loader, binary_loader)), total=len(dataset_binary) 371 | ): 372 | for j in range(len(new_batch)): 373 | if not np.array_equal(old_batch[j], new_batch[j]): 374 | raise ValueError("FAILED: Datasets not equal") 375 | if i > len(dataset_binary): 376 | break 377 | print("PASSED") 378 | 379 | 380 | if __name__ == "__main__": 381 | _test() 382 | _test_bin() 383 | -------------------------------------------------------------------------------- /extend_distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import builtins 7 | import os 8 | import sys 9 | 10 | import torch 11 | import torch.distributed as dist 12 | from torch.autograd import Function 13 | from torch.autograd.profiler import record_function 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | 16 | 17 | try: 18 | import torch_ccl 19 | except ImportError as e: 20 | # print(e) 21 | torch_ccl = False 22 | 23 | try: 24 | import torch_ucc 25 | except ImportError as e: 26 | torch_ucc = False 27 | 28 | 29 | my_rank = -1 30 | my_size = -1 31 | my_local_rank = -1 32 | my_local_size = -1 33 | alltoall_supported = False 34 | a2a_impl = os.environ.get("DLRM_ALLTOALL_IMPL", "") 35 | 36 | myreq = None 37 | 38 | 39 | def env2int(env_list, default=-1): 40 | for e in env_list: 41 | val = int(os.environ.get(e, -1)) 42 | if val >= 0: 43 | return val 44 | return default 45 | 46 | 47 | def get_my_slice(n): 48 | k, m = divmod(n, my_size) 49 | return slice( 50 | my_rank * k + min(my_rank, m), (my_rank + 1) * k + min(my_rank + 1, m), 1 51 | ) 52 | 53 | 54 | def get_split_lengths(n): 55 | k, m = divmod(n, my_size) 56 | if m == 0: 57 | splits = None 58 | my_len = k 59 | else: 60 | splits = [(k + 1) if i < m else k for i in range(my_size)] 61 | my_len = splits[my_rank] 62 | return (my_len, splits) 63 | 64 | 65 | def init_distributed(rank=-1, local_rank=-1, size=-1, use_gpu=False, backend=""): 66 | global myreq 67 | global my_rank 68 | global my_size 69 | global my_local_rank 70 | global my_local_size 71 | global a2a_impl 72 | global alltoall_supported 73 | 74 | # guess MPI ranks from env (works for IMPI, OMPI and MVAPICH2) 75 | num_mpi_ranks = env2int( 76 | ["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"] 77 | ) 78 | if backend == "" and num_mpi_ranks > 1: 79 | if torch_ccl and env2int(["CCL_WORKER_COUNT"]) > 0: 80 | backend = "ccl" 81 | elif use_gpu and dist.is_nccl_available(): 82 | backend = "nccl" 83 | elif dist.is_mpi_available(): 84 | backend = "mpi" 85 | else: 86 | print( 87 | "WARNING: MPI multi-process launch detected but PyTorch MPI backend not available." 
88 | ) 89 | backend = "gloo" 90 | 91 | if backend != "": 92 | # guess Rank and size 93 | if rank == -1: 94 | rank = env2int( 95 | ["PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK", "RANK"], 0 96 | ) 97 | if size == -1: 98 | size = env2int( 99 | [ 100 | "PMI_SIZE", 101 | "OMPI_COMM_WORLD_SIZE", 102 | "MV2_COMM_WORLD_SIZE", 103 | "WORLD_SIZE", 104 | ], 105 | 1, 106 | ) 107 | if not os.environ.get("RANK", None) and rank != -1: 108 | os.environ["RANK"] = str(rank) 109 | if not os.environ.get("WORLD_SIZE", None) and size != -1: 110 | os.environ["WORLD_SIZE"] = str(size) 111 | if not os.environ.get("MASTER_PORT", None): 112 | os.environ["MASTER_PORT"] = "29500" 113 | if not os.environ.get("MASTER_ADDR", None): 114 | local_size = env2int( 115 | [ 116 | "MPI_LOCALNRANKS", 117 | "OMPI_COMM_WORLD_LOCAL_SIZE", 118 | "MV2_COMM_WORLD_LOCAL_SIZE", 119 | ], 120 | 1, 121 | ) 122 | if local_size != size and backend != "mpi": 123 | print( 124 | "Warning: Looks like distributed multinode run but MASTER_ADDR env not set, using '127.0.0.1' as default" 125 | ) 126 | print( 127 | "If this run hangs, try exporting rank 0's hostname as MASTER_ADDR" 128 | ) 129 | os.environ["MASTER_ADDR"] = "127.0.0.1" 130 | 131 | if size > 1: 132 | if local_rank == -1: 133 | my_local_rank = env2int( 134 | [ 135 | "MPI_LOCALRANKID", 136 | "OMPI_COMM_WORLD_LOCAL_RANK", 137 | "MV2_COMM_WORLD_LOCAL_RANK", 138 | "LOCAL_RANK", 139 | ], 140 | 0, 141 | ) 142 | else: 143 | my_local_rank = local_rank 144 | my_local_size = env2int( 145 | [ 146 | "MPI_LOCALNRANKS", 147 | "OMPI_COMM_WORLD_LOCAL_SIZE", 148 | "MV2_COMM_WORLD_LOCAL_SIZE", 149 | ], 150 | 1, 151 | ) 152 | if use_gpu: 153 | if my_local_size > torch.cuda.device_count(): 154 | print( 155 | "Not sufficient GPUs available... local_size = %d, ngpus = %d" 156 | % (my_local_size, torch.cuda.device_count()) 157 | ) 158 | sys.exit(1) 159 | torch.cuda.set_device(my_local_rank) 160 | dist.init_process_group(backend, rank=rank, world_size=size) 161 | my_rank = dist.get_rank() 162 | my_size = dist.get_world_size() 163 | if my_rank == 0: 164 | print("Running on %d ranks using %s backend" % (my_size, backend)) 165 | if hasattr(dist, "all_to_all_single"): 166 | try: 167 | t = torch.zeros([4]) 168 | if use_gpu: 169 | t = t.cuda() 170 | dist.all_to_all_single(t, t) 171 | alltoall_supported = True 172 | except RuntimeError as err: 173 | print("fail to enable all_to_all_single primitive: %s" % err) 174 | if a2a_impl == "alltoall" and alltoall_supported == False: 175 | print( 176 | "Requested DLRM_ALLTOALL_IMPL=%s but backend %s does not support it, use scatter/gather based alltoall" 177 | % (a2a_impl, backend) 178 | ) 179 | a2a_impl = "scatter" 180 | if a2a_impl != "": 181 | print("Using DLRM_ALLTOALL_IMPL=%s" % a2a_impl) 182 | else: 183 | my_rank = 0 184 | my_size = 1 185 | my_local_rank = 0 186 | my_local_size = 1 187 | print_all( 188 | "world size: %d, current rank: %d, local rank: %d" 189 | % (my_size, my_rank, my_local_rank) 190 | ) 191 | myreq = Request() 192 | 193 | 194 | class Request(object): 195 | def __init__(self): 196 | self.req = None 197 | self.tensor = None 198 | self.WaitFunction = All2All_Scatter_Wait 199 | 200 | def wait(self): 201 | ret = self.WaitFunction.apply(*self.tensor) 202 | self.req = None 203 | self.tensor = None 204 | return ret 205 | 206 | 207 | class All2All_ScatterList_Req(Function): 208 | @staticmethod 209 | def forward(ctx, a2a_info, *inputs): 210 | global myreq 211 | batch_split_lengths = ( 212 | a2a_info.global_batch_partition_slices 213 | if 
a2a_info.global_batch_partition_slices 214 | else a2a_info.local_batch_num 215 | ) 216 | table_split_lengths = ( 217 | a2a_info.global_table_wise_parition_slices 218 | if a2a_info.global_table_wise_parition_slices 219 | else [a2a_info.local_table_num] * my_size 220 | ) 221 | gather_list = [] 222 | req_list = [] 223 | for i in range(my_size): 224 | for j in range(table_split_lengths[i]): 225 | out_tensor = inputs[0].new_empty( 226 | [a2a_info.local_batch_num, a2a_info.emb_dim] 227 | ) 228 | scatter_list = ( 229 | list(inputs[j].split(batch_split_lengths, dim=0)) 230 | if i == my_rank 231 | else [] 232 | ) 233 | req = dist.scatter(out_tensor, scatter_list, src=i, async_op=True) 234 | gather_list.append(out_tensor) 235 | req_list.append(req) 236 | myreq.req = req_list 237 | myreq.tensor = tuple(gather_list) 238 | myreq.a2a_info = a2a_info 239 | return myreq.tensor 240 | 241 | @staticmethod 242 | def backward(ctx, *grad_output): 243 | global myreq 244 | for r in myreq.req: 245 | r.wait() 246 | myreq.req = None 247 | grad_inputs = myreq.tensor 248 | myreq.tensor = None 249 | return (None, *grad_inputs) 250 | 251 | 252 | class All2All_ScatterList_Wait(Function): 253 | @staticmethod 254 | def forward(ctx, *output): 255 | global myreq 256 | ctx.a2a_info = myreq.a2a_info 257 | for r in myreq.req: 258 | r.wait() 259 | myreq.req = None 260 | myreq.tensor = None 261 | return output 262 | 263 | @staticmethod 264 | def backward(ctx, *grad_output): 265 | global myreq 266 | a2a_info = ctx.a2a_info 267 | grad_output = [t.contiguous() for t in grad_output] 268 | batch_split_lengths = ( 269 | a2a_info.global_batch_partition_slices 270 | if a2a_info.global_batch_partition_slices 271 | else [a2a_info.local_batch_num] * my_size 272 | ) 273 | per_rank_table_splits = ( 274 | a2a_info.global_table_wise_parition_slices 275 | if a2a_info.global_table_wise_parition_slices 276 | else [a2a_info.local_table_num] * my_size 277 | ) 278 | grad_inputs = [ 279 | grad_output[0].new_empty([ctx.a2a_info.batch_size, ctx.a2a_info.emb_dim]) 280 | for _ in range(a2a_info.local_table_num) 281 | ] 282 | req_list = [] 283 | ind = 0 284 | for i in range(my_size): 285 | for j in range(per_rank_table_splits[i]): 286 | gather_list = ( 287 | list(grad_inputs[j].split(batch_split_lengths, dim=0)) 288 | if i == my_rank 289 | else None 290 | ) 291 | req = dist.gather(grad_output[ind], gather_list, dst=i, async_op=True) 292 | req_list.append(req) 293 | ind += 1 294 | myreq.req = req_list 295 | myreq.tensor = grad_inputs 296 | return tuple(grad_output) 297 | 298 | 299 | class All2All_Scatter_Req(Function): 300 | @staticmethod 301 | def forward(ctx, a2a_info, *inputs): 302 | global myreq 303 | batch_split_lengths = ( 304 | a2a_info.global_batch_partition_slices 305 | if a2a_info.global_batch_partition_slices 306 | else a2a_info.local_batch_num 307 | ) 308 | table_split_lengths = ( 309 | a2a_info.global_table_wise_parition_slices 310 | if a2a_info.global_table_wise_parition_slices 311 | else [a2a_info.local_table_num] * my_size 312 | ) 313 | input = torch.cat(inputs, dim=1) 314 | scatter_list = list(input.split(batch_split_lengths, dim=0)) 315 | gather_list = [] 316 | req_list = [] 317 | for i in range(my_size): 318 | out_tensor = input.new_empty( 319 | [a2a_info.local_batch_num, table_split_lengths[i] * a2a_info.emb_dim] 320 | ) 321 | req = dist.scatter( 322 | out_tensor, scatter_list if i == my_rank else [], src=i, async_op=True 323 | ) 324 | gather_list.append(out_tensor) 325 | req_list.append(req) 326 | myreq.req = req_list 327 | myreq.tensor 
= tuple(gather_list) 328 | myreq.a2a_info = a2a_info 329 | ctx.a2a_info = a2a_info 330 | return myreq.tensor 331 | 332 | @staticmethod 333 | def backward(ctx, *grad_output): 334 | global myreq 335 | for r in myreq.req: 336 | r.wait() 337 | myreq.req = None 338 | grad_input = myreq.tensor 339 | grad_inputs = grad_input.split(ctx.a2a_info.emb_dim, dim=1) 340 | myreq.tensor = None 341 | return (None, *grad_inputs) 342 | 343 | 344 | class All2All_Scatter_Wait(Function): 345 | @staticmethod 346 | def forward(ctx, *output): 347 | global myreq 348 | ctx.a2a_info = myreq.a2a_info 349 | for r in myreq.req: 350 | r.wait() 351 | myreq.req = None 352 | myreq.tensor = None 353 | return output 354 | 355 | @staticmethod 356 | def backward(ctx, *grad_output): 357 | global myreq 358 | assert len(grad_output) == my_size 359 | scatter_list = [t.contiguous() for t in grad_output] 360 | a2a_info = ctx.a2a_info 361 | batch_split_lengths = ( 362 | a2a_info.global_batch_partition_slices 363 | if a2a_info.global_batch_partition_slices 364 | else a2a_info.local_batch_num 365 | ) 366 | table_split_lengths = ( 367 | a2a_info.global_table_wise_parition_slices 368 | if a2a_info.global_table_wise_parition_slices 369 | else [a2a_info.local_table_num] * my_size 370 | ) 371 | grad_input = grad_output[0].new_empty( 372 | [a2a_info.batch_size, a2a_info.emb_dim * a2a_info.local_table_num] 373 | ) 374 | gather_list = list(grad_input.split(batch_split_lengths, dim=0)) 375 | req_list = [] 376 | for i in range(my_size): 377 | req = dist.gather( 378 | scatter_list[i], 379 | gather_list if i == my_rank else [], 380 | dst=i, 381 | async_op=True, 382 | ) 383 | req_list.append(req) 384 | myreq.req = req_list 385 | myreq.tensor = grad_input 386 | return grad_output 387 | 388 | 389 | class All2All_Req(Function): 390 | @staticmethod 391 | def forward(ctx, a2a_info, *inputs): 392 | global myreq 393 | with record_function("DLRM alltoall_req_fwd_single"): 394 | batch_split_lengths = a2a_info.global_batch_partition_slices 395 | if batch_split_lengths: 396 | batch_split_lengths = [ 397 | m * a2a_info.emb_dim * a2a_info.local_table_num 398 | for m in batch_split_lengths 399 | ] 400 | table_split_lengths = a2a_info.global_table_wise_parition_slices 401 | if table_split_lengths: 402 | table_split_lengths = [ 403 | a2a_info.local_batch_num * e * a2a_info.emb_dim 404 | for e in table_split_lengths 405 | ] 406 | input = torch.cat(inputs, dim=1).view([-1]) 407 | output = input.new_empty( 408 | [ 409 | a2a_info.global_table_num 410 | * a2a_info.local_batch_num 411 | * a2a_info.emb_dim 412 | ] 413 | ) 414 | req = dist.all_to_all_single( 415 | output, input, table_split_lengths, batch_split_lengths, async_op=True 416 | ) 417 | 418 | myreq.req = req 419 | myreq.tensor = [] 420 | myreq.tensor.append(output) 421 | myreq.tensor = tuple(myreq.tensor) 422 | a2a_info.batch_split_lengths = batch_split_lengths 423 | a2a_info.table_split_lengths = table_split_lengths 424 | myreq.a2a_info = a2a_info 425 | ctx.a2a_info = a2a_info 426 | return myreq.tensor 427 | 428 | @staticmethod 429 | def backward(ctx, *grad_output): 430 | global myreq 431 | with record_function("DLRM alltoall_req_bwd_single"): 432 | a2a_info = ctx.a2a_info 433 | myreq.req.wait() 434 | myreq.req = None 435 | grad_input = myreq.tensor 436 | grad_inputs = grad_input.view([a2a_info.batch_size, -1]).split( 437 | a2a_info.emb_dim, dim=1 438 | ) 439 | grad_inputs = [gin.contiguous() for gin in grad_inputs] 440 | myreq.tensor = None 441 | return (None, *grad_inputs) 442 | 443 | 444 | class 
All2All_Wait(Function):
445 | @staticmethod
446 | def forward(ctx, *output):
447 | global myreq
448 | with record_function("DLRM alltoall_wait_fwd_single"):
449 | a2a_info = myreq.a2a_info
450 | ctx.a2a_info = a2a_info
451 | myreq.req.wait()
452 | myreq.req = None
453 | myreq.tensor = None
454 | table_split_lengths = (
455 | a2a_info.table_split_lengths
456 | if a2a_info.table_split_lengths
457 | else a2a_info.local_table_num
458 | * a2a_info.local_batch_num
459 | * a2a_info.emb_dim
460 | )
461 | outputs = output[0].split(table_split_lengths)
462 | outputs = tuple(
463 | [out.view([a2a_info.local_batch_num, -1]) for out in outputs]
464 | )
465 | return outputs
466 | 
467 | @staticmethod
468 | def backward(ctx, *grad_outputs):
469 | global myreq
470 | with record_function("DLRM alltoall_wait_bwd_single"):
471 | a2a_info = ctx.a2a_info
472 | grad_outputs = [gout.contiguous().view([-1]) for gout in grad_outputs]
473 | grad_output = torch.cat(grad_outputs)
474 | grad_input = grad_output.new_empty(
475 | [a2a_info.batch_size * a2a_info.local_table_num * a2a_info.emb_dim]
476 | )
477 | req = dist.all_to_all_single(
478 | grad_input,
479 | grad_output,
480 | a2a_info.batch_split_lengths,
481 | a2a_info.table_split_lengths,
482 | async_op=True,
483 | )
484 | myreq.req = req
485 | myreq.tensor = grad_input
486 | return (grad_output,)
487 | 
488 | 
489 | class AllGather(Function):
490 | @staticmethod
491 | def forward(ctx, input, global_lengths, dim=0):
492 | if not isinstance(global_lengths, (list, tuple)):
493 | global_lengths = [global_lengths] * my_size
494 | 
495 | assert len(global_lengths) == my_size
496 | assert global_lengths[my_rank] == input.size(dim)
497 | local_start = sum(global_lengths[:my_rank])
498 | 
499 | output_size = list(input.size())
500 | 
501 | ctx.dim = dim
502 | ctx.local_start = local_start
503 | ctx.local_length = global_lengths[my_rank]
504 | 
505 | input = input.contiguous()
506 | if dim == 0:
507 | out_len = sum(global_lengths)
508 | output_size[dim] = out_len
509 | output = input.new_empty(output_size)
510 | gather_list = list(output.split(global_lengths, dim=0))
511 | else:
512 | # buffer sizes may differ along `dim`, so allocate one buffer per rank below
513 | gather_list = []
514 | for length in global_lengths:
515 | output_size[dim] = length
516 | gather_list.append(input.new_empty(output_size))
517 | 
518 | dist.all_gather(gather_list, input)
519 | 
520 | if dim != 0:
521 | output = torch.cat(gather_list, dim=dim)
522 | 
523 | return output
524 | 
525 | @staticmethod
526 | def backward(ctx, grad_output):
527 | # print("Inside All2AllBackward")
528 | dim = ctx.dim
529 | start = ctx.local_start
530 | length = ctx.local_length
531 | 
532 | grad_input = grad_output.narrow(dim, start, length)
533 | 
534 | return (grad_input, None, None)
535 | 
536 | 
537 | class All2AllInfo(object):
538 | pass
539 | 
540 | 
541 | def alltoall(inputs, per_rank_table_splits):
542 | global myreq
543 | batch_size, emb_dim = inputs[0].size()
544 | a2a_info = All2AllInfo()
545 | a2a_info.local_table_num = len(inputs)
546 | a2a_info.global_table_wise_parition_slices = per_rank_table_splits
547 | (
548 | a2a_info.local_batch_num,
549 | a2a_info.global_batch_partition_slices,
550 | ) = get_split_lengths(batch_size)
551 | a2a_info.emb_dim = emb_dim
552 | a2a_info.batch_size = batch_size
553 | a2a_info.global_table_num = (
554 | sum(per_rank_table_splits)
555 | if per_rank_table_splits
556 | else a2a_info.local_table_num * my_size
557 | )
558 | 
559 | if (a2a_impl == "" and alltoall_supported) or a2a_impl == "alltoall":
560 | # print("Using All2All_Req")
561 | output = All2All_Req.apply(a2a_info, *inputs)
562 | myreq.WaitFunction = All2All_Wait
563 | elif a2a_impl == "" or a2a_impl == "scatter":
564 | # print("Using All2All_Scatter_Req")
565 | output = All2All_Scatter_Req.apply(a2a_info, *inputs)
566 | myreq.WaitFunction = All2All_Scatter_Wait
567 | elif a2a_impl == "scatter_list":
568 | # print("Using All2All_ScatterList_Req")
569 | output = All2All_ScatterList_Req.apply(a2a_info, *inputs)
570 | myreq.WaitFunction = All2All_ScatterList_Wait
571 | else:
572 | print(
573 | "Unknown value set for DLRM_ALLTOALL_IMPL (%s), "
574 | "please use one of [alltoall, scatter, scatter_list]" % a2a_impl
575 | )
576 | return myreq
577 | 
578 | 
579 | def all_gather(input, lengths, dim=0):
580 | if not lengths:
581 | lengths = [input.size(0)] * my_size
582 | return AllGather.apply(input, lengths, dim)
583 | 
584 | 
585 | def barrier():
586 | if my_size > 1:
587 | dist.barrier()
588 | 
589 | 
590 | # Override builtin print function to print only from rank 0
591 | orig_print = builtins.print
592 | 
593 | 
594 | def rank0_print(*args, **kwargs):
595 | if kwargs.pop("print_all", False) or my_rank <= 0:
596 | orig_print(*args, **kwargs)
597 | 
598 | 
599 | builtins.print = rank0_print
600 | 
601 | 
602 | # Allow printing from all ranks with an explicit print_all()
603 | def print_all(*args, **kwargs):
604 | orig_print(*args, **kwargs)
605 | 
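For orientation, here is a hypothetical end-to-end usage sketch of this module (not part of the file): two processes exchange per-table embedding outputs with `alltoall`, and `myreq.wait()` blocks on the asynchronous collectives. `init_distributed` falls back to the scatter/gather path automatically when the backend lacks `all_to_all_single`; the shapes and the launch command are illustrative assumptions.

```
# launch with: torchrun --nproc_per_node=2 example.py
import torch
import extend_distributed as ext_dist

ext_dist.init_distributed(use_gpu=False, backend="gloo")

batch_size, emb_dim = 8, 4
# one local embedding-table output on this rank, tracked by autograd
local_emb_out = [torch.randn(batch_size, emb_dim, requires_grad=True)]

req = ext_dist.alltoall(local_emb_out, None)  # None => equal per-rank splits
shards = req.wait()  # tuple of local-batch shards of the exchanged embeddings
loss = sum(s.sum() for s in shards)
loss.backward()  # gradients flow back through the same collectives
```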
# print("Using All2All_Req") 561 | output = All2All_Req.apply(a2a_info, *inputs) 562 | myreq.WaitFunction = All2All_Wait 563 | elif a2a_impl == "" or a2a_impl == "scatter": 564 | # print("Using All2All_Scatter_Req") 565 | output = All2All_Scatter_Req.apply(a2a_info, *inputs) 566 | myreq.WaitFunction = All2All_Scatter_Wait 567 | elif a2a_impl == "scatter_list": 568 | # print("Using All2All_ScatterList_Req") 569 | output = All2All_ScatterList_Req.apply(a2a_info, *inputs) 570 | myreq.WaitFunction = All2All_ScatterList_Wait 571 | else: 572 | print( 573 | "Unknown value set for DLRM_ALLTOALL_IMPL (%s), " 574 | "please use one of [alltoall, scatter, scatter_list]" % a2a_impl 575 | ) 576 | return myreq 577 | 578 | 579 | def all_gather(input, lengths, dim=0): 580 | if not lengths: 581 | lengths = [input.size(0)] * my_size 582 | return AllGather.apply(input, lengths, dim) 583 | 584 | 585 | def barrier(): 586 | if my_size > 1: 587 | dist.barrier() 588 | 589 | 590 | # Override builtin print function to print only from rank 0 591 | orig_print = builtins.print 592 | 593 | 594 | def rank0_print(*args, **kwargs): 595 | if my_rank <= 0 or kwargs.get("print_all", False): 596 | orig_print(*args, **kwargs) 597 | 598 | 599 | builtins.print = rank0_print 600 | 601 | 602 | # Allow printing from all rank with explicit print_all 603 | def print_all(*args, **kwargs): 604 | orig_print(*args, **kwargs) 605 | -------------------------------------------------------------------------------- /input/dist_emb_0.log: -------------------------------------------------------------------------------- 1 | 1, 2, 3, 4, 5, 6 2 | 0, 1, 3, 4, 5 3 | 0.55, 0.64, 0.82, 0.91, 1.0 4 | -------------------------------------------------------------------------------- /input/dist_emb_1.log: -------------------------------------------------------------------------------- 1 | 1, 2, 3, 4, 5, 6 2 | 0, 1, 3, 4, 5 3 | 0.55, 0.64, 0.82, 0.91, 1.0 4 | -------------------------------------------------------------------------------- /input/dist_emb_2.log: -------------------------------------------------------------------------------- 1 | 1, 2, 3, 4, 5, 6 2 | 0, 1, 3, 4, 5 3 | 0.55, 0.64, 0.82, 0.91, 1.0 4 | -------------------------------------------------------------------------------- /input/trace.log: -------------------------------------------------------------------------------- 1 | 1, 2, 3, 4, 5, 3, 4, 1, 1, 6, 3 2 | -------------------------------------------------------------------------------- /kaggle_dac_loss_accuracy_plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/dlrm/ddd0fdcd3c9b71aa719e77266c912274a5692735/kaggle_dac_loss_accuracy_plots.png -------------------------------------------------------------------------------- /mlperf_logger.py: -------------------------------------------------------------------------------- 1 | # @lint-ignore-every LICENSELINT 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | """ 9 | Utilities for MLPerf logging 10 | """ 11 | 12 | import os 13 | 14 | import torch 15 | 16 | try: 17 | from mlperf_logging import mllog 18 | from mlperf_logging.mllog import constants 19 | 20 | _MLLOGGER = mllog.get_mllogger() 21 | except ImportError as error: 22 | print("Unable to import mlperf_logging, ", error) 23 | 24 | 25 | def log_start(*args, **kwargs): 26 | "log with start tag" 27 | _log_print(_MLLOGGER.start, *args, **kwargs) 28 | 29 | 30 | def log_end(*args, **kwargs): 31 | "log with end tag" 32 | _log_print(_MLLOGGER.end, *args, **kwargs) 33 | 34 | 35 | def log_event(*args, **kwargs): 36 | "log with event tag" 37 | _log_print(_MLLOGGER.event, *args, **kwargs) 38 | 39 | 40 | def _log_print(logger, *args, **kwargs): 41 | "makes mlperf logger aware of distributed execution" 42 | if "stack_offset" not in kwargs: 43 | kwargs["stack_offset"] = 3 44 | if "value" not in kwargs: 45 | kwargs["value"] = None 46 | 47 | if kwargs.pop("log_all_ranks", False): 48 | log = True 49 | else: 50 | log = get_rank() == 0 51 | 52 | if log: 53 | logger(*args, **kwargs) 54 | 55 | 56 | def config_logger(benchmark): 57 | "initiates mlperf logger" 58 | mllog.config( 59 | filename=os.path.join( 60 | os.path.dirname(os.path.abspath(__file__)), f"{benchmark}.log" 61 | ) 62 | ) 63 | _MLLOGGER.logger.propagate = False 64 | 65 | 66 | def barrier(): 67 | """ 68 | Works as a temporary distributed barrier, currently pytorch 69 | doesn't implement barrier for NCCL backend. 70 | Calls all_reduce on dummy tensor and synchronizes with GPU. 71 | """ 72 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 73 | torch.distributed.all_reduce(torch.cuda.FloatTensor(1)) 74 | torch.cuda.synchronize() 75 | 76 | 77 | def get_rank(): 78 | """ 79 | Gets distributed rank or returns zero if distributed is not initialized. 80 | """ 81 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 82 | rank = torch.distributed.get_rank() 83 | else: 84 | rank = 0 85 | return rank 86 | 87 | 88 | def mlperf_submission_log(benchmark): 89 | """ 90 | Logs information needed for MLPerf submission 91 | """ 92 | 93 | config_logger(benchmark) 94 | 95 | log_event( 96 | key=constants.SUBMISSION_BENCHMARK, 97 | value=benchmark, 98 | ) 99 | 100 | log_event(key=constants.SUBMISSION_ORG, value="reference_implementation") 101 | 102 | log_event(key=constants.SUBMISSION_DIVISION, value="closed") 103 | 104 | log_event(key=constants.SUBMISSION_STATUS, value="onprem") 105 | 106 | log_event(key=constants.SUBMISSION_PLATFORM, value="reference_implementation") 107 | 108 | log_event(key=constants.SUBMISSION_ENTRY, value="reference_implementation") 109 | 110 | log_event(key=constants.SUBMISSION_POC_NAME, value="reference_implementation") 111 | 112 | log_event(key=constants.SUBMISSION_POC_EMAIL, value="reference_implementation") 113 | -------------------------------------------------------------------------------- /optim/rwsadagrad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import torch 8 | from torch.optim import Optimizer 9 | 10 | 11 | class RWSAdagrad(Optimizer): 12 | """Implements Row Wise Sparse Adagrad algorithm. 
13 | 14 | Arguments: 15 | params (iterable): iterable of parameters to optimize or dicts defining 16 | parameter groups 17 | lr (float, optional): learning rate (default: 1e-2) 18 | lr_decay (float, optional): learning rate decay (default: 0) 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 20 | eps (float, optional): term added to the denominator to improve 21 | numerical stability (default: 1e-10) 22 | 23 | """ 24 | 25 | def __init__( 26 | self, 27 | params, 28 | lr=1e-2, 29 | lr_decay=0.0, 30 | weight_decay=0.0, 31 | initial_accumulator_value=0.0, 32 | eps=1e-10, 33 | ): 34 | if not 0.0 <= lr: 35 | raise ValueError("Invalid learning rate: {}".format(lr)) 36 | if not 0.0 <= lr_decay: 37 | raise ValueError("Invalid lr_decay value: {}".format(lr_decay)) 38 | if not 0.0 <= weight_decay: 39 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 40 | if not 0.0 <= initial_accumulator_value: 41 | raise ValueError( 42 | "Invalid initial_accumulator_value value: {}".format( 43 | initial_accumulator_value 44 | ) 45 | ) 46 | if not 0.0 <= eps: 47 | raise ValueError("Invalid epsilon value: {}".format(eps)) 48 | 49 | self.defaults = dict( 50 | lr=lr, 51 | lr_decay=lr_decay, 52 | eps=eps, 53 | weight_decay=weight_decay, 54 | initial_accumulator_value=initial_accumulator_value, 55 | ) 56 | super(RWSAdagrad, self).__init__(params, self.defaults) 57 | 58 | self.momentum_initialized = False 59 | 60 | for group in self.param_groups: 61 | for p in group["params"]: 62 | self.state[p]["step"] = 0 63 | 64 | def share_memory(self): 65 | for group in self.param_groups: 66 | for p in group["params"]: 67 | state = self.state[p] 68 | if p.grad.data.is_sparse: 69 | state["momentum"].share_memory_() 70 | else: 71 | state["sum"].share_memory_() 72 | 73 | def step(self, closure=None): 74 | """Performs a single optimization step. 75 | 76 | Arguments: 77 | closure (callable, optional): A closure that reevaluates the model 78 | and returns the loss. 
79 | """ 80 | loss = None 81 | if closure is not None: 82 | loss = closure() 83 | 84 | for group in self.param_groups: 85 | for p in group["params"]: 86 | if p.grad is None: 87 | continue 88 | 89 | if not self.momentum_initialized: 90 | if p.grad.data.is_sparse: 91 | self.state[p]["momentum"] = torch.full( 92 | [p.data.shape[0]], 93 | self.defaults["initial_accumulator_value"], 94 | dtype=torch.float32, 95 | ) 96 | else: 97 | self.state[p]["sum"] = torch.full_like( 98 | p.data, 99 | self.defaults["initial_accumulator_value"], 100 | dtype=torch.float32, 101 | ) 102 | 103 | grad = p.grad 104 | state = self.state[p] 105 | 106 | state["step"] += 1 107 | 108 | if group["weight_decay"] != 0: 109 | if p.grad.data.is_sparse: 110 | raise RuntimeError( 111 | "weight_decay option is not compatible with sparse gradients" 112 | ) 113 | grad = grad.add(group["weight_decay"], p.data) 114 | 115 | clr = group["lr"] / (1.0 + (state["step"] - 1.0) * group["lr_decay"]) 116 | 117 | if grad.is_sparse: 118 | grad = ( 119 | grad.coalesce() 120 | ) # the update is non-linear so indices must be unique 121 | grad_indices = grad._indices() 122 | grad_values = grad._values() 123 | size = grad.size() 124 | 125 | def make_sparse(values, row_wise): 126 | constructor = grad.new 127 | matrix_size = [size[0]] if row_wise else size 128 | return constructor(grad_indices, values, matrix_size) 129 | 130 | if grad_values.numel() > 0: 131 | momentum_update = make_sparse( 132 | grad_values.pow(2).mean(dim=1), True 133 | ) 134 | state["momentum"].add_(momentum_update) # update momentum 135 | std = state["momentum"].sparse_mask(momentum_update.coalesce()) 136 | std_values = std._values().sqrt_().add_(group["eps"]) 137 | p.data.add_( 138 | make_sparse( 139 | grad_values / std_values.view(std_values.size()[0], 1), 140 | False, 141 | ), 142 | alpha=-clr, 143 | ) 144 | 145 | else: 146 | state["sum"].addcmul_(grad, grad, value=1.0) 147 | std = state["sum"].sqrt().add_(group["eps"]) 148 | p.data.addcdiv_(grad, std, value=-clr) 149 | 150 | self.momentum_initialized = True 151 | 152 | return loss 153 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | future 2 | numpy 3 | onnx 4 | pydot 5 | torch 6 | torchviz 7 | scikit-learn 8 | tqdm 9 | torchrec-nightly 10 | torchx-nightly 11 | -------------------------------------------------------------------------------- /terabyte_0875_loss_accuracy_plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/dlrm/ddd0fdcd3c9b71aa719e77266c912274a5692735/terabyte_0875_loss_accuracy_plots.png -------------------------------------------------------------------------------- /test/dlrm_s_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | #WARNING: must have compiled PyTorch and caffe2 8 | 9 | #check if extra argument is passed to the test 10 | if [[ $# == 1 ]]; then 11 | dlrm_extra_option=$1 12 | else 13 | dlrm_extra_option="" 14 | fi 15 | #echo $dlrm_extra_option 16 | 17 | dlrm_py="python dlrm_s_pytorch.py" 18 | dlrm_c2="python dlrm_s_caffe2.py" 19 | 20 | echo "Running commands ..." 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | future
2 | numpy
3 | onnx
4 | pydot
5 | torch
6 | torchviz
7 | scikit-learn
8 | tqdm
9 | torchrec-nightly
10 | torchx-nightly
--------------------------------------------------------------------------------
/terabyte_0875_loss_accuracy_plots.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/dlrm/ddd0fdcd3c9b71aa719e77266c912274a5692735/terabyte_0875_loss_accuracy_plots.png
--------------------------------------------------------------------------------
/test/dlrm_s_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | #WARNING: must have compiled PyTorch and caffe2
8 | 
9 | #check if extra argument is passed to the test
10 | if [[ $# == 1 ]]; then
11 | dlrm_extra_option=$1
12 | else
13 | dlrm_extra_option=""
14 | fi
15 | #echo $dlrm_extra_option
16 | 
17 | dlrm_py="python dlrm_s_pytorch.py"
18 | dlrm_c2="python dlrm_s_caffe2.py"
19 | 
20 | echo "Running commands ..."
21 | #run pytorch
22 | echo $dlrm_py
23 | $dlrm_py --mini-batch-size=1 --data-size=1 --nepochs=1 --arch-interaction-op=dot --learning-rate=0.1 --debug-mode $dlrm_extra_option > ppp1
24 | $dlrm_py --mini-batch-size=2 --data-size=4 --nepochs=1 --arch-interaction-op=dot --learning-rate=0.1 --debug-mode $dlrm_extra_option > ppp2
25 | $dlrm_py --mini-batch-size=2 --data-size=5 --nepochs=1 --arch-interaction-op=dot --learning-rate=0.1 --debug-mode $dlrm_extra_option > ppp3
26 | $dlrm_py --mini-batch-size=2 --data-size=5 --nepochs=3 --arch-interaction-op=dot --learning-rate=0.1 --debug-mode $dlrm_extra_option > ppp4
27 | 
28 | #run caffe2
29 | echo $dlrm_c2
30 | $dlrm_c2 --mini-batch-size=1 --data-size=1 --nepochs=1 --arch-interaction-op=dot --learning-rate=0.1 --debug-mode $dlrm_extra_option > ccc1
31 | $dlrm_c2 --mini-batch-size=2 --data-size=4 --nepochs=1 --arch-interaction-op=dot --learning-rate=0.1 --debug-mode $dlrm_extra_option > ccc2
32 | $dlrm_c2 --mini-batch-size=2 --data-size=5 --nepochs=1 --arch-interaction-op=dot --learning-rate=0.1 --debug-mode $dlrm_extra_option > ccc3
33 | $dlrm_c2 --mini-batch-size=2 --data-size=5 --nepochs=3 --arch-interaction-op=dot --learning-rate=0.1 --debug-mode $dlrm_extra_option > ccc4
34 | 
35 | echo "Checking results ..."
36 | #check results
37 | #WARNING: correct test will have no difference in numeric values
38 | #(but might have some verbal difference, e.g. due to warnings)
39 | #in the output file
40 | echo "diff test1 (no numeric values in the output = SUCCESS)"
41 | diff ccc1 ppp1
42 | echo "diff test2 (no numeric values in the output = SUCCESS)"
43 | diff ccc2 ppp2
44 | echo "diff test3 (no numeric values in the output = SUCCESS)"
45 | diff ccc3 ppp3
46 | echo "diff test4 (no numeric values in the output = SUCCESS)"
47 | diff ccc4 ppp4
48 | 
--------------------------------------------------------------------------------
/torchrec_dlrm/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG FROM_IMAGE_NAME=pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime
2 | FROM ${FROM_IMAGE_NAME}
3 | 
4 | WORKDIR /workspace/torchrec_dlrm
5 | COPY . .
6 | 
7 | RUN pip install --no-cache-dir -r requirements.txt
--------------------------------------------------------------------------------
/torchrec_dlrm/README.MD:
--------------------------------------------------------------------------------
1 | # TorchRec DLRM Example
2 | 
3 | `dlrm_main.py` trains, validates, and tests a [Deep Learning Recommendation Model](https://arxiv.org/abs/1906.00091) (DLRM) with TorchRec. The DLRM model contains both data parallel components (e.g. multi-layer perceptrons & interaction arch) and model parallel components (e.g. embedding tables). The DLRM model is pipelined so that dataloading, data-parallel to model-parallel comms, and forward/backward are overlapped. It can be run with either a random dataloader or the [Criteo 1 TB click logs dataset](https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/).
4 | 
5 | It has been tested on the following cloud instance types:
6 | | Cloud | Instance Type | GPUs | vCPUs | Memory (GB) |
7 | | ------ | ------------------- | ---------------- | ----- | ----------- |
8 | | AWS | p4d.24xlarge | 8 x A100 (40GB) | 96 | 1152 |
9 | | Azure | Standard_ND96asr_v4 | 8 x A100 (40GB) | 96 | 900 |
10 | | GCP | a2-megagpu-16g | 16 x A100 (40GB) | 96 | 1300 |
11 | 
12 | A basic understanding of [TorchRec](https://github.com/pytorch/torchrec) will help in understanding `dlrm_main.py`.
See this [tutorial](https://pytorch.org/tutorials/intermediate/torchrec_tutorial.html).
13 | 
14 | # Running
15 | 
16 | ## Install dependencies
17 | `pip install tqdm torchmetrics`
18 | 
19 | ## Torchx
20 | We recommend using [torchx](https://pytorch.org/torchx/main/quickstart.html) to run. Here we use the [DDP builtin](https://pytorch.org/torchx/main/components/distributed.html).
21 | 
22 | 1. pip install torchx
23 | 2. (optional) setup a slurm or kubernetes cluster
24 | 3.
25 | a. locally: `torchx run -s local_cwd dist.ddp -j 1x2 --script dlrm_main.py`
26 | b. remotely: `torchx run -s slurm dist.ddp -j 1x8 --script dlrm_main.py`
27 | 
28 | ## TorchRun
29 | You can also use [torchrun](https://pytorch.org/docs/stable/elastic/run.html).
30 | * e.g. `torchrun --nnodes 1 --nproc_per_node 2 --rdzv_backend c10d --rdzv_endpoint localhost --rdzv_id 54321 --role trainer dlrm_main.py`
31 | 
32 | 
33 | ## Preliminary Training Results
34 | 
35 | **Setup:**
36 | * Dataset: Criteo 1TB Click Logs dataset
37 | * CUDA 11.0, NCCL 2.10.3.
38 | * AWS p4d24xlarge instances, each with 8 40GB NVIDIA A100s.
39 | 
40 | **Results**
41 | 
42 | Common settings across all runs:
43 | 
44 | ```
45 | --num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 --embedding_dim 128 --pin_memory --over_arch_layer_sizes 1024,1024,512,256,1 --dense_arch_layer_sizes 512,256,128 --epochs 1
46 | ```
47 | 
48 | |Number of GPUs|Collective Size of Embedding Tables (GiB)|Local Batch Size|Global Batch Size|Learning Rate|Interaction Type|Optimizer|AUROC over Val Set After 1 Epoch|AUROC Over Test Set After 1 Epoch|Training speed|Time to Train 1 Epoch|Unique Flags|
49 | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
50 | |8|104.54|256|2,048|1.0|Dot product interaction|SGD|0.8032|0.8030|~100.0 batches/s == ~204,800 samples/s|6h30m08s |`--batch_size 256 --learning_rate 1.0`|
51 | |8|104.54|2,048|16,384|0.006|Dot product interaction|Adagrad|0.8021|0.7959|~56.5 batches/s == ~925,696 samples/s|1h16m15s |`--batch_size 2048 --learning_rate 0.006 --adagrad` |
52 | |8|104.54|2,048|16,384|0.006|DCN v2|Adagrad|0.8035|0.7973|~55.0 batches/s == ~901,120 samples/s|1h20m21s |`--batch_size 2048 --learning_rate 0.006 --adagrad --interaction_type=dcn` |
53 | |8|104.54|16,384|131,072|0.006|DCN v2|Adagrad|0.8025|0.7963|~9.08 batches/s == ~1,190,128 samples/s|58m 49s |`--batch_size 16384 --learning_rate 0.006 --adagrad --interaction_type=dcn`|
54 | 
55 | Training speed is calculated using the formula: `average it/s * local batch size * number of GPUs used`. For example, the second row works out to 56.5 it/s * 2,048 * 8 GPUs = ~925,696 samples/s. The benchmark displays `it/s` measurements
56 | during the run.
57 | 
58 | **Reproduce**
59 | 
60 | Run the following command to reproduce the results for a single node (8 GPUs) on AWS. This command makes use of the `aws_component.py` script.
61 | 
62 | Make sure to (example values are shown below):
63 | - set $PATH_TO_1TB_NUMPY_FILES to the path with the pre-processed .npy files of the Criteo 1TB dataset.
64 | - set $TRAIN_QUEUE to the partition that handles training jobs
65 | 
66 | **NVTabular**
67 | NVTabular offers an alternative way of preprocessing the dataset, which can decrease the time required from several days to just hours; see the run instructions [here](https://github.com/pytorch/torchrec/tree/main/examples/nvt_dataloader).
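For concreteness, the two variables referenced in the Reproduce step might be set as follows (both values are illustrative placeholders for your own cluster):

```
export PATH_TO_1TB_NUMPY_FILES=/data/criteo_1tb/numpy_contiguous_shuffled_output_dataset_dir
export TRAIN_QUEUE=train
```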
68 | 
69 | Preprocessing command (numpy):
70 | 
71 | After downloading and uncompressing the Criteo 1TB Click Logs dataset (consisting of 24 files, from [day 0](https://storage.googleapis.com/criteo-cail-datasets/day_0.gz) to [day 23](https://storage.googleapis.com/criteo-cail-datasets/day_23.gz)), process the raw tsv files into the proper format for training by running `./scripts/process_Criteo_1TB_Click_Logs_dataset.sh` with the necessary command line arguments.
72 | 
73 | Example usage:
74 | 
75 | ```
76 | bash ./scripts/process_Criteo_1TB_Click_Logs_dataset.sh \
77 | ./criteo_1tb/raw_input_dataset_dir \
78 | ./criteo_1tb/temp_intermediate_files_dir \
79 | ./criteo_1tb/numpy_contiguous_shuffled_output_dataset_dir
80 | ```
81 | 
82 | The script requires 700GB of RAM and takes 1-2 days to run. We currently have features in development to reduce the preprocessing time and memory overhead.
83 | MD5 checksums of the expected final preprocessed dataset files are in md5sums_preprocessed_criteo_click_logs_dataset.txt.
84 | 
85 | We are working on improving this experience; for updates, see https://github.com/pytorch/torchrec/tree/main/examples/nvt_dataloader
86 | 
87 | 
88 | Example command:
89 | ```
90 | torchx run --scheduler slurm --scheduler_args partition=$TRAIN_QUEUE,time=5:00:00 aws_component.py:run_dlrm_main --num_trainers=8 -- --pin_memory --batch_size 2048 --epochs 1 --num_embeddings_per_feature "45833188,36746,17245,7413,20243,3,7114,1441,62,29275261,1572176,345138,10,2209,11267,128,4,974,14,48937457,11316796,40094537,452104,12606,104,35" --embedding_dim 128 --dense_arch_layer_sizes 512,256,128 --over_arch_layer_sizes 1024,1024,512,256,1 --in_memory_binary_criteo_path $PATH_TO_1TB_NUMPY_FILES --learning_rate 15.0
91 | ```
92 | Upon scheduling the job, there should be output that looks like this:
93 | 
94 | ```
95 | warnings.warn(
96 | slurm://torchx/14731
97 | torchx 2022-01-07 21:06:59 INFO Launched app: slurm://torchx/14731
98 | torchx 2022-01-07 21:06:59 INFO AppStatus:
99 | msg: ''
100 | num_restarts: -1
101 | roles: []
102 | state: UNKNOWN (7)
103 | structured_error_msg: 
104 | ui_url: null
105 | 
106 | torchx 2022-01-07 21:06:59 INFO Job URL: None
107 | ```
108 | 
109 | In this example, the job was launched to: `slurm://torchx/14731`.
110 | 
111 | Run the following commands to check the status of your job and read the logs:
112 | 
113 | ```
114 | # Status should be "RUNNING" if properly scheduled
115 | torchx status slurm://torchx/14731
116 | 
117 | # Log file was automatically created in the directory where you launched the job from
118 | cat slurm-14731.out
119 | 
120 | ```
121 | 
122 | The results from the training can be found in the log file (e.g. `slurm-14731.out`).
123 | 
124 | **Debugging**
125 | 
126 | The `--validation_freq_within_epoch x` parameter can be used to print the AUROC every `x` iterations through an epoch.
127 | 
128 | The in-memory dataloader can take approximately 20-30 minutes to load the data into memory before training starts. The
129 | `--mmap_mode` parameter can be used to load data from disk, which reduces start-up time for training at the cost
130 | of QPS.
131 | 
132 | **Inference**
133 | A module which can be used for DLRM inference exists [here](https://github.com/pytorch/torchrec/blob/main/examples/inference/dlrm_predict.py#L49). Please see the [TorchRec inference examples](https://github.com/pytorch/torchrec/tree/main/examples/inference) for more information.
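As a quick sanity check before launching long training runs, the preprocessed day files can be inspected directly with NumPy. This is a minimal sketch assuming the standard one-hot layout produced by the preprocessing script (13 dense features, 26 sparse features, and one label per row; `day_N_*.npy` file names as used by the dataloader):

```
import numpy as np

day = 0  # any of the 24 preprocessed days
dense = np.load(f"day_{day}_dense.npy", mmap_mode="r")    # expect (N, 13) floats
sparse = np.load(f"day_{day}_sparse.npy", mmap_mode="r")  # expect (N, 26) int32 ids
labels = np.load(f"day_{day}_labels.npy", mmap_mode="r")  # expect (N, 1) 0/1 labels

assert dense.shape[0] == sparse.shape[0] == labels.shape[0]
print(dense.shape, sparse.shape, labels.shape)
```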
134 | 
135 | # Running the MLPerf DLRM v2 benchmark
136 | 
137 | ## Create the synthetic multi-hot dataset
138 | ### Step 1: Download and uncompress the [Criteo 1TB Click Logs dataset](https://storage.googleapis.com/criteo-cail-datasets/day_{0-23}.gz)
139 | 
140 | ### Step 2: Run the 1TB Criteo Preprocess script.
141 | Example usage:
142 | 
143 | ```
144 | bash ./scripts/process_Criteo_1TB_Click_Logs_dataset.sh \
145 | ./criteo_1tb/raw_input_dataset_dir \
146 | ./criteo_1tb/temp_intermediate_files_dir \
147 | ./criteo_1tb/numpy_contiguous_shuffled_output_dataset_dir
148 | ```
149 | 
150 | The script requires 700GB of RAM and takes 1-2 days to run. MD5 checksums for the output dataset files are in md5sums_preprocessed_criteo_click_logs_dataset.txt.
151 | 
152 | ### Step 3: Run the `materialize_synthetic_multihot_dataset.py` script
153 | #### Single-process version:
154 | ```
155 | python materialize_synthetic_multihot_dataset.py \
156 | --in_memory_binary_criteo_path $PREPROCESSED_CRITEO_1TB_CLICK_LOGS_DATASET_PATH \
157 | --output_path $MATERIALIZED_DATASET_PATH \
158 | --num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 \
159 | --multi_hot_sizes 3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1 \
160 | --multi_hot_distribution_type uniform
161 | ```
162 | #### Multiple-processes version:
163 | ```
164 | torchx run -s local_cwd dist.ddp -j 1x8 --script -- materialize_synthetic_multihot_dataset.py -- \
165 | --in_memory_binary_criteo_path $PREPROCESSED_CRITEO_1TB_CLICK_LOGS_DATASET_PATH \
166 | --output_path $MATERIALIZED_DATASET_PATH \
167 | --num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 \
168 | --multi_hot_sizes 3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1 \
169 | --multi_hot_distribution_type uniform
170 | ```
171 | 
172 | ### Run the MLPerf DLRM v2 benchmark, which uses the materialized multi-hot dataset
173 | Example running 8 GPUs:
174 | ```
175 | export MULTIHOT_PREPROCESSED_DATASET=$your_path_here
176 | export TOTAL_TRAINING_SAMPLES=4195197692 ;
177 | export GLOBAL_BATCH_SIZE=65536 ;
178 | export WORLD_SIZE=8 ;
179 | torchx run -s local_cwd dist.ddp -j 1x8 --script dlrm_main.py -- \
180 | --embedding_dim 128 \
181 | --dense_arch_layer_sizes 512,256,128 \
182 | --over_arch_layer_sizes 1024,1024,512,256,1 \
183 | --synthetic_multi_hot_criteo_path $MULTIHOT_PREPROCESSED_DATASET \
184 | --num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 \
185 | --validation_freq_within_epoch $((TOTAL_TRAINING_SAMPLES / (GLOBAL_BATCH_SIZE * 20))) \
186 | --epochs 1 \
187 | --pin_memory \
188 | --mmap_mode \
189 | --batch_size $((GLOBAL_BATCH_SIZE / WORLD_SIZE)) \
190 | --interaction_type=dcn \
191 | --dcn_num_layers=3 \
192 | --dcn_low_rank_dim=512 \
193 | --adagrad \
194 | --learning_rate 0.005
195 | ```
196 | Note: The proposed target AUROC to reach within one epoch is 0.8030.
197 | 
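As a quick check of the two derived values in the command above, the shell arithmetic expands as follows:

```
echo $((65536 / 8))                  # --batch_size per rank: 8192
echo $((4195197692 / (65536 * 20)))  # --validation_freq_within_epoch: 3200, i.e. ~20 validations per epoch
```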
198 | ## (Alternative method that trains on multi-hot data generated on-the-fly)
199 | 
200 | It is possible to use the 1-hot preprocessed dataset (the output of `./scripts/process_Criteo_1TB_Click_Logs_dataset.sh`) to create the synthetic multi-hot data on-the-fly during training. This is useful if your system does not have the space to store the 3.8 TB materialized multi-hot dataset. Example run command:
201 | 
202 | ```
203 | export PREPROCESSED_DATASET=$insert_your_path_here
204 | export TOTAL_TRAINING_SAMPLES=4195197692 ;
205 | export GLOBAL_BATCH_SIZE=65536 ;
206 | export WORLD_SIZE=8 ;
207 | torchx run -s local_cwd dist.ddp -j 1x8 --script dlrm_main.py -- \
208 | --embedding_dim 128 \
209 | --dense_arch_layer_sizes 512,256,128 \
210 | --over_arch_layer_sizes 1024,1024,512,256,1 \
211 | --in_memory_binary_criteo_path $PREPROCESSED_DATASET \
212 | --num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 \
213 | --validation_freq_within_epoch $((TOTAL_TRAINING_SAMPLES / (GLOBAL_BATCH_SIZE * 20))) \
214 | --epochs 1 \
215 | --pin_memory \
216 | --mmap_mode \
217 | --batch_size $((GLOBAL_BATCH_SIZE / WORLD_SIZE)) \
218 | --interaction_type=dcn \
219 | --dcn_num_layers=3 \
220 | --dcn_low_rank_dim=512 \
221 | --adagrad \
222 | --learning_rate 0.005 \
223 | --multi_hot_distribution_type uniform \
224 | --multi_hot_sizes=3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1
225 | ```
226 | # Replicating the MLPerf DLRM v1 benchmark using the [TorchRec-based implementation](./torchrec_dlrm/dlrm_main.py)
227 | 
228 | ## Create the 1-hot preprocessed dataset
229 | ### Step 1: [Download](./torchrec_dlrm/scripts/download_Criteo_1TB_Click_Logs_dataset.sh) and uncompress the Criteo 1TB Click Logs dataset (24 files from [day 0](https://storage.googleapis.com/criteo-cail-datasets/day_0.gz) to [day 23](https://storage.googleapis.com/criteo-cail-datasets/day_23.gz))
230 | 
231 | ### Step 2: Run the 1TB Criteo Preprocess script.
232 | Example usage:
233 | 
234 | ```
235 | bash ./scripts/process_Criteo_1TB_Click_Logs_dataset.sh \
236 | ./criteo_1tb/raw_input_dataset_dir \
237 | ./criteo_1tb/temp_intermediate_files_dir \
238 | ./criteo_1tb/numpy_contiguous_shuffled_output_dataset_dir
239 | ```
240 | 
241 | The script requires 700GB of RAM and takes 1-2 days to run. MD5 checksums for the output dataset files are in md5sums_preprocessed_criteo_click_logs_dataset.txt.
242 | 
243 | ## Run the TorchRec-based implementation with the MLPerf DLRM v1 benchmark settings
244 | 
245 | Example running 8 GPUs:
246 | ```
247 | export PREPROCESSED_DATASET=$insert_your_path_here
248 | export TOTAL_TRAINING_SAMPLES=4195197692 ;
249 | export GLOBAL_BATCH_SIZE=16384 ;
250 | export WORLD_SIZE=8 ;
251 | torchx run -s local_cwd dist.ddp -j 1x8 --script dlrm_main.py -- \
252 | --embedding_dim 128 \
253 | --dense_arch_layer_sizes 512,256,128 \
254 | --over_arch_layer_sizes 1024,1024,512,256,1 \
255 | --in_memory_binary_criteo_path $PREPROCESSED_DATASET \
256 | --num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 \
257 | --validation_freq_within_epoch $((TOTAL_TRAINING_SAMPLES / (GLOBAL_BATCH_SIZE * 20))) \
258 | --epochs 1 \
259 | --pin_memory \
260 | --mmap_mode \
261 | --batch_size $((GLOBAL_BATCH_SIZE / WORLD_SIZE)) \
262 | --learning_rate 1.0
263 | ```
264 | ## Comparison of MLPerf DLRM Benchmark Settings: v1 vs. v2:
265 | 
266 | ||v1|v2|
267 | | --- | --- | --- |
268 | |Optimizer|SGD|Adagrad|
269 | |Learning rate|1.0|0.005|
270 | |Batch size|16384|65536|
271 | |Interaction type|Dot product|DCN v2|
272 | |Benchmark Script|[v1](https://github.com/facebookresearch/dlrm/blob/mlperf/dlrm_s_pytorch.py)|[v2 (using TorchRec)](./torchrec_dlrm/dlrm_main.py)|
273 | |Dataset preprocessing scripts/instructions|[v1](https://github.com/facebookresearch/dlrm/blob/main/data_utils.py)|[v2](https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm#create-the-synthetic-multi-hot-dataset)|
274 | |Synthetically-generated multi-hot sparse features|No (uses 1-hot sparse features)|Yes (synthetically-generated multi-hot sparse features derived from the original 1-hot sparse features)|
275 | 
276 | # Criteo Kaggle Display Advertising Challenge dataset usage
277 | 
278 | ### Prerequisites
279 | - Python >= 3.9
280 | - CUDA >= 12.0
281 | 
282 | ### Setup environment
283 | Install the PyTorch nightly version
284 | ```bash
285 | pip install torch --index-url https://download.pytorch.org/whl/nightly/cu126
286 | ```
287 | Install FBGEMM-GPU
288 | ```bash
289 | pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu126
290 | ```
291 | Install torchrec from a local build
292 | ```bash
293 | git clone https://github.com/pytorch/torchrec.git
294 | python -m pip install -e torchrec
295 | ```
296 | Install additional dependencies
297 | ```bash
298 | pip install -r requirements.txt
299 | ```
300 | 
301 | ### Download the dataset
302 | ```
303 | wget http://go.criteo.net/criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz
304 | ```
305 | ### Uncompress
306 | ```
307 | tar zxvf criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz
308 | ```
309 | ### Preprocess the dataset to numpy files
310 | ```
311 | python -m torchrec.datasets.scripts.npy_preproc_criteo --input_dir $INPUT_PATH --output_dir $OUTPUT_PATH --dataset_name criteo_kaggle
312 | ```
313 | ### Run the benchmark
314 | ```
315 | export PREPROCESSED_DATASET=$insert_your_path_here
316 | export GLOBAL_BATCH_SIZE=16384 ;
317 | export WORLD_SIZE=8 ;
318 | export LEARNING_RATE=0.5 ;
319 | torchx run -s local_cwd dist.ddp -j 1x8 --script dlrm_main.py -- \
320 | --in_memory_binary_criteo_path $PREPROCESSED_DATASET \
321 | --pin_memory \
322 | --mmap_mode \
323 | --batch_size $((GLOBAL_BATCH_SIZE / WORLD_SIZE)) \
324 | --learning_rate $LEARNING_RATE \
325 | --dataset_name criteo_kaggle \
326 | --num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 \
327 | --embedding_dim 128 \
328 | --over_arch_layer_sizes 1024,1024,512,256,1 \
329 | --dense_arch_layer_sizes 512,256,128 \
330 | --epochs 1 \
331 | --validation_freq_within_epoch 12802
332 | ```
333 | 
--------------------------------------------------------------------------------
/torchrec_dlrm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/dlrm/ddd0fdcd3c9b71aa719e77266c912274a5692735/torchrec_dlrm/__init__.py
--------------------------------------------------------------------------------
/torchrec_dlrm/aws_component.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | 9 | import torchx.specs as specs 10 | from torchx.components.dist import ddp 11 | from torchx.specs.api import Resource 12 | 13 | 14 | def run_dlrm_main(num_trainers: int = 8, *script_args: str) -> specs.AppDef: 15 | """ 16 | Args: 17 | num_trainers: The number of trainers to use. 18 | script_args: A variable number of parameters to provide dlrm_main.py. 19 | """ 20 | cwd = os.getcwd() 21 | entrypoint = os.path.join(cwd, "dlrm_main.py") 22 | 23 | user = os.environ.get("USER") 24 | image = f"/data/home/{user}" 25 | 26 | if num_trainers > 8 and num_trainers % 8 != 0: 27 | raise ValueError( 28 | "Trainer jobs spanning multiple hosts must be in multiples of 8." 29 | ) 30 | nproc_per_node = 8 if num_trainers >= 8 else num_trainers 31 | num_replicas = max(num_trainers // 8, 1) 32 | 33 | return ddp( 34 | *script_args, 35 | name="train_dlrm", 36 | image=image, 37 | # AWS p4d instance (https://aws.amazon.com/ec2/instance-types/p4/). 38 | cpu=96, 39 | gpu=8, 40 | memMB=-1, 41 | script=entrypoint, 42 | j=f"{num_replicas}x{nproc_per_node}", 43 | ) 44 | -------------------------------------------------------------------------------- /torchrec_dlrm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/dlrm/ddd0fdcd3c9b71aa719e77266c912274a5692735/torchrec_dlrm/data/__init__.py -------------------------------------------------------------------------------- /torchrec_dlrm/data/dlrm_dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import argparse 8 | import os 9 | from typing import List 10 | 11 | from torch import distributed as dist 12 | from torch.utils.data import DataLoader 13 | from torchrec.datasets.criteo import ( 14 | CAT_FEATURE_COUNT, 15 | DAYS, 16 | DEFAULT_CAT_NAMES, 17 | DEFAULT_INT_NAMES, 18 | InMemoryBinaryCriteoIterDataPipe, 19 | ) 20 | from torchrec.datasets.random import RandomRecDataset 21 | 22 | # OSS import 23 | try: 24 | # pyre-ignore[21] 25 | # @manual=//ai_codesign/benchmarks/dlrm/torchrec_dlrm/data:multi_hot_criteo 26 | from data.multi_hot_criteo import MultiHotCriteoIterDataPipe 27 | 28 | except ImportError: 29 | pass 30 | 31 | # internal import 32 | try: 33 | from .multi_hot_criteo import MultiHotCriteoIterDataPipe # noqa F811 34 | except ImportError: 35 | pass 36 | 37 | STAGES = ["train", "val", "test"] 38 | 39 | 40 | def _get_random_dataloader( 41 | args: argparse.Namespace, 42 | stage: str, 43 | ) -> DataLoader: 44 | attr = f"limit_{stage}_batches" 45 | num_batches = getattr(args, attr) 46 | if stage in ["val", "test"] and args.test_batch_size is not None: 47 | batch_size = args.test_batch_size 48 | else: 49 | batch_size = args.batch_size 50 | return DataLoader( 51 | RandomRecDataset( 52 | keys=DEFAULT_CAT_NAMES, 53 | batch_size=batch_size, 54 | hash_size=args.num_embeddings, 55 | hash_sizes=( 56 | args.num_embeddings_per_feature 57 | if hasattr(args, "num_embeddings_per_feature") 58 | else None 59 | ), 60 | manual_seed=getattr(args, "seed", None), 61 | ids_per_feature=1, 62 | num_dense=len(DEFAULT_INT_NAMES), 63 | num_batches=num_batches, 64 | ), 65 | batch_size=None, 66 | batch_sampler=None, 67 | pin_memory=args.pin_memory, 68 | num_workers=0, 69 | ) 70 | 71 | 72 | def _get_in_memory_dataloader( 73 | args: argparse.Namespace, 74 | stage: str, 75 | ) -> DataLoader: 76 | if args.in_memory_binary_criteo_path is not None: 77 | dir_path = args.in_memory_binary_criteo_path 78 | sparse_part = "sparse.npy" 79 | datapipe = InMemoryBinaryCriteoIterDataPipe 80 | else: 81 | dir_path = args.synthetic_multi_hot_criteo_path 82 | sparse_part = "sparse_multi_hot.npz" 83 | datapipe = MultiHotCriteoIterDataPipe 84 | 85 | if args.dataset_name == "criteo_kaggle": 86 | # criteo_kaggle has no validation set, so use 2nd half of training set for now. 87 | # Setting stage to "test" will get the 2nd half of the dataset. 88 | # Setting root_name to "train" reads from the training set file. 
89 | (root_name, stage) = ( 90 | ("train", "train") if stage == "train" else ("train", "test") 91 | ) 92 | stage_files: List[List[str]] = [ 93 | [os.path.join(dir_path, f"{root_name}_dense.npy")], 94 | [os.path.join(dir_path, f"{root_name}_{sparse_part}")], 95 | [os.path.join(dir_path, f"{root_name}_labels.npy")], 96 | ] 97 | # criteo_1tb code path uses below two conditionals 98 | elif stage == "train": 99 | stage_files: List[List[str]] = [ 100 | [os.path.join(dir_path, f"day_{i}_dense.npy") for i in range(DAYS - 1)], 101 | [os.path.join(dir_path, f"day_{i}_{sparse_part}") for i in range(DAYS - 1)], 102 | [os.path.join(dir_path, f"day_{i}_labels.npy") for i in range(DAYS - 1)], 103 | ] 104 | elif stage in ["val", "test"]: 105 | stage_files: List[List[str]] = [ 106 | [os.path.join(dir_path, f"day_{DAYS-1}_dense.npy")], 107 | [os.path.join(dir_path, f"day_{DAYS-1}_{sparse_part}")], 108 | [os.path.join(dir_path, f"day_{DAYS-1}_labels.npy")], 109 | ] 110 | if stage in ["val", "test"] and args.test_batch_size is not None: 111 | batch_size = args.test_batch_size 112 | else: 113 | batch_size = args.batch_size 114 | dataloader = DataLoader( 115 | datapipe( 116 | stage, 117 | *stage_files, # pyre-ignore[6] 118 | batch_size=batch_size, 119 | rank=dist.get_rank(), 120 | world_size=dist.get_world_size(), 121 | drop_last=args.drop_last_training_batch if stage == "train" else False, 122 | shuffle_batches=args.shuffle_batches, 123 | shuffle_training_set=args.shuffle_training_set, 124 | shuffle_training_set_random_seed=args.seed, 125 | mmap_mode=args.mmap_mode, 126 | hashes=( 127 | args.num_embeddings_per_feature 128 | if args.num_embeddings is None 129 | else ([args.num_embeddings] * CAT_FEATURE_COUNT) 130 | ), 131 | ), 132 | batch_size=None, 133 | pin_memory=args.pin_memory, 134 | collate_fn=lambda x: x, 135 | ) 136 | return dataloader 137 | 138 | 139 | def get_dataloader(args: argparse.Namespace, backend: str, stage: str) -> DataLoader: 140 | """ 141 | Gets desired dataloader from dlrm_main command line options. Currently, this 142 | function is able to return either a DataLoader wrapped around a RandomRecDataset or 143 | a Dataloader wrapped around an InMemoryBinaryCriteoIterDataPipe. 144 | 145 | Args: 146 | args (argparse.Namespace): Command line options supplied to dlrm_main.py's main 147 | function. 148 | backend (str): "nccl" or "gloo". 149 | stage (str): "train", "val", or "test". 150 | 151 | Returns: 152 | dataloader (DataLoader): PyTorch dataloader for the specified options. 153 | 154 | """ 155 | stage = stage.lower() 156 | if stage not in STAGES: 157 | raise ValueError(f"Supplied stage was {stage}. Must be one of {STAGES}.") 158 | 159 | args.pin_memory = ( 160 | (backend == "nccl") if not hasattr(args, "pin_memory") else args.pin_memory 161 | ) 162 | 163 | if ( 164 | args.in_memory_binary_criteo_path is None 165 | and args.synthetic_multi_hot_criteo_path is None 166 | ): 167 | return _get_random_dataloader(args, stage) 168 | else: 169 | return _get_in_memory_dataloader(args, stage) 170 | -------------------------------------------------------------------------------- /torchrec_dlrm/data/multi_hot_criteo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import zipfile 8 | from typing import Dict, Iterator, List, Optional 9 | 10 | import numpy as np 11 | import torch 12 | from iopath.common.file_io import PathManager, PathManagerFactory 13 | from pyre_extensions import none_throws 14 | from torch.utils.data import IterableDataset 15 | from torchrec.datasets.criteo import CAT_FEATURE_COUNT, DEFAULT_CAT_NAMES 16 | from torchrec.datasets.utils import Batch, PATH_MANAGER_KEY 17 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor 18 | 19 | 20 | class MultiHotCriteoIterDataPipe(IterableDataset): 21 | """ 22 | Datapipe designed to operate over the MLPerf DLRM v2 synthetic multi-hot dataset. 23 | This dataset can be created by following the steps in 24 | torchrec_dlrm/scripts/materialize_synthetic_multihot_dataset.py. 25 | Each rank reads only the data for the portion of the dataset it is responsible for. 26 | 27 | Args: 28 | stage (str): "train", "val", or "test". 29 | dense_paths (List[str]): List of path strings to dense npy files. 30 | sparse_paths (List[str]): List of path strings to multi-hot sparse npz files. 31 | labels_paths (List[str]): List of path strings to labels npy files. 32 | batch_size (int): batch size. 33 | rank (int): rank. 34 | world_size (int): world size. 35 | drop_last (Optional[bool]): Whether to drop the last batch if it is incomplete. 36 | shuffle_batches (bool): Whether to shuffle batches 37 | shuffle_training_set (bool): Whether to shuffle all samples in the dataset. 38 | shuffle_training_set_random_seed (int): The random generator seed used when 39 | shuffling the training set. 40 | hashes (Optional[int]): List of max categorical feature value for each feature. 41 | Length of this list should be CAT_FEATURE_COUNT. 42 | path_manager_key (str): Path manager key used to load from different 43 | filesystems. 44 | 45 | Example:: 46 | 47 | datapipe = MultiHotCriteoIterDataPipe( 48 | dense_paths=["day_0_dense.npy"], 49 | sparse_paths=["day_0_sparse_multi_hot.npz"], 50 | labels_paths=["day_0_labels.npy"], 51 | batch_size=1024, 52 | rank=torch.distributed.get_rank(), 53 | world_size=torch.distributed.get_world_size(), 54 | ) 55 | batch = next(iter(datapipe)) 56 | """ 57 | 58 | def __init__( 59 | self, 60 | stage: str, 61 | dense_paths: List[str], 62 | sparse_paths: List[str], 63 | labels_paths: List[str], 64 | batch_size: int, 65 | rank: int, 66 | world_size: int, 67 | drop_last: Optional[bool] = False, 68 | shuffle_batches: bool = False, 69 | shuffle_training_set: bool = False, 70 | shuffle_training_set_random_seed: int = 0, 71 | mmap_mode: bool = False, 72 | hashes: Optional[List[int]] = None, 73 | path_manager_key: str = PATH_MANAGER_KEY, 74 | ) -> None: 75 | self.stage = stage 76 | self.dense_paths = dense_paths 77 | self.sparse_paths = sparse_paths 78 | self.labels_paths = labels_paths 79 | self.batch_size = batch_size 80 | self.rank = rank 81 | self.world_size = world_size 82 | self.drop_last = drop_last 83 | self.shuffle_batches = shuffle_batches 84 | self.shuffle_training_set = shuffle_training_set 85 | np.random.seed(shuffle_training_set_random_seed) 86 | self.mmap_mode = mmap_mode 87 | # hashes are not used because they were already applied in the 88 | # script that generates the multi-hot dataset. 
89 | self.hashes: np.ndarray = np.array(hashes).reshape((CAT_FEATURE_COUNT, 1)) 90 | self.path_manager_key = path_manager_key 91 | self.path_manager: PathManager = PathManagerFactory().get(path_manager_key) 92 | 93 | if shuffle_training_set and stage == "train": 94 | # Currently not implemented for the materialized multi-hot dataset. 95 | self._shuffle_and_load_data_for_rank() 96 | else: 97 | m = "r" if mmap_mode else None 98 | self.dense_arrs: List[np.ndarray] = [ 99 | np.load(f, mmap_mode=m) for f in self.dense_paths 100 | ] 101 | self.labels_arrs: List[np.ndarray] = [ 102 | np.load(f, mmap_mode=m) for f in self.labels_paths 103 | ] 104 | self.sparse_arrs: List = [] 105 | for sparse_path in self.sparse_paths: 106 | multi_hot_ids_l = [] 107 | for feat_id_num in range(CAT_FEATURE_COUNT): 108 | multi_hot_ft_ids = self._load_from_npz( 109 | sparse_path, f"{feat_id_num}.npy" 110 | ) 111 | multi_hot_ids_l.append(multi_hot_ft_ids) 112 | self.sparse_arrs.append(multi_hot_ids_l) 113 | len_d0 = len(self.dense_arrs[0]) 114 | second_half_start_index = int(len_d0 // 2 + len_d0 % 2) 115 | if stage == "val": 116 | self.dense_arrs[0] = self.dense_arrs[0][:second_half_start_index, :] 117 | self.labels_arrs[0] = self.labels_arrs[0][:second_half_start_index, :] 118 | self.sparse_arrs[0] = [ 119 | feats[:second_half_start_index, :] for feats in self.sparse_arrs[0] 120 | ] 121 | elif stage == "test": 122 | self.dense_arrs[0] = self.dense_arrs[0][second_half_start_index:, :] 123 | self.labels_arrs[0] = self.labels_arrs[0][second_half_start_index:, :] 124 | self.sparse_arrs[0] = [ 125 | feats[second_half_start_index:, :] for feats in self.sparse_arrs[0] 126 | ] 127 | # When mmap_mode is enabled, sparse features are hashed when 128 | # samples are batched in def __iter__. Otherwise, the dataset has been 129 | # preloaded with sparse features hashed in the preload stage, here: 130 | # if not self.mmap_mode and self.hashes is not None: 131 | # for k, _ in enumerate(self.sparse_arrs): 132 | # self.sparse_arrs[k] = [ 133 | # feat % hash 134 | # for (feat, hash) in zip(self.sparse_arrs[k], self.hashes) 135 | # ] 136 | 137 | self.num_rows_per_file: List[int] = list(map(len, self.dense_arrs)) 138 | total_rows = sum(self.num_rows_per_file) 139 | self.num_full_batches: int = ( 140 | total_rows // batch_size // self.world_size * self.world_size 141 | ) 142 | self.last_batch_sizes: np.ndarray = np.array( 143 | [0 for _ in range(self.world_size)] 144 | ) 145 | remainder = total_rows % (self.world_size * batch_size) 146 | if not self.drop_last and 0 < remainder: 147 | if remainder < self.world_size: 148 | self.num_full_batches -= self.world_size 149 | self.last_batch_sizes += batch_size 150 | else: 151 | self.last_batch_sizes += remainder // self.world_size 152 | self.last_batch_sizes[: remainder % self.world_size] += 1 153 | 154 | self.multi_hot_sizes: List[int] = [ 155 | multi_hot_feat.shape[-1] for multi_hot_feat in self.sparse_arrs[0] 156 | ] 157 | 158 | # These values are the same for the KeyedJaggedTensors in all batches, so they 159 | # are computed once here. This avoids extra work from the KeyedJaggedTensor sync 160 | # functions. 
161 | self.keys: List[str] = DEFAULT_CAT_NAMES 162 | self.index_per_key: Dict[str, int] = { 163 | key: i for (i, key) in enumerate(self.keys) 164 | } 165 | 166 | def _load_from_npz(self, fname, npy_name): 167 | # figure out offset of .npy in .npz 168 | zf = zipfile.ZipFile(fname) 169 | info = zf.NameToInfo[npy_name] 170 | assert info.compress_type == 0 171 | zf.fp.seek(info.header_offset + len(info.FileHeader()) + 20) 172 | # read .npy header 173 | zf.open(npy_name, "r") 174 | version = np.lib.format.read_magic(zf.fp) 175 | shape, fortran_order, dtype = np.lib.format._read_array_header(zf.fp, version) 176 | assert ( 177 | dtype == "int32" 178 | ), f"sparse multi-hot dtype is {dtype} but should be int32" 179 | offset = zf.fp.tell() 180 | # create memmap 181 | return np.memmap( 182 | zf.filename, 183 | dtype=dtype, 184 | shape=shape, 185 | order="F" if fortran_order else "C", 186 | mode="r", 187 | offset=offset, 188 | ) 189 | 190 | def _np_arrays_to_batch( 191 | self, 192 | dense: np.ndarray, 193 | sparse: List[np.ndarray], 194 | labels: np.ndarray, 195 | ) -> Batch: 196 | if self.shuffle_batches: 197 | # Shuffle all 3 in unison 198 | shuffler = np.random.permutation(len(dense)) 199 | sparse = [multi_hot_ft[shuffler, :] for multi_hot_ft in sparse] 200 | dense = dense[shuffler] 201 | labels = labels[shuffler] 202 | 203 | batch_size = len(dense) 204 | lengths = torch.ones((CAT_FEATURE_COUNT * batch_size), dtype=torch.int32) 205 | for k, multi_hot_size in enumerate(self.multi_hot_sizes): 206 | lengths[k * batch_size : (k + 1) * batch_size] = multi_hot_size 207 | offsets = torch.cumsum(torch.concat((torch.tensor([0]), lengths)), dim=0) 208 | length_per_key = [ 209 | batch_size * multi_hot_size for multi_hot_size in self.multi_hot_sizes 210 | ] 211 | offset_per_key = torch.cumsum( 212 | torch.concat((torch.tensor([0]), torch.tensor(length_per_key))), dim=0 213 | ) 214 | values = torch.concat([torch.from_numpy(feat).flatten() for feat in sparse]) 215 | return Batch( 216 | dense_features=torch.from_numpy(dense.copy()), 217 | sparse_features=KeyedJaggedTensor( 218 | keys=self.keys, 219 | values=values, 220 | lengths=lengths, 221 | offsets=offsets, 222 | stride=batch_size, 223 | length_per_key=length_per_key, 224 | offset_per_key=offset_per_key.tolist(), 225 | index_per_key=self.index_per_key, 226 | ), 227 | labels=torch.from_numpy(labels.reshape(-1).copy()), 228 | ) 229 | 230 | def __iter__(self) -> Iterator[Batch]: 231 | # Invariant: buffer never contains more than batch_size rows. 232 | buffer: Optional[List[np.ndarray]] = None 233 | 234 | def append_to_buffer( 235 | dense: np.ndarray, 236 | sparse: List[np.ndarray], 237 | labels: np.ndarray, 238 | ) -> None: 239 | nonlocal buffer 240 | if buffer is None: 241 | buffer = [dense, sparse, labels] 242 | else: 243 | buffer[0] = np.concatenate((buffer[0], dense)) 244 | buffer[1] = [np.concatenate((b, s)) for b, s in zip(buffer[1], sparse)] 245 | buffer[2] = np.concatenate((buffer[2], labels)) 246 | 247 | # Maintain a buffer that can contain up to batch_size rows. Fill buffer as 248 | # much as possible on each iteration. Only return a new batch when batch_size 249 | # rows are filled. 
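        # Worked example (hypothetical file sizes): with batch_size=4 and two
        # files holding 3 and 5 rows, the first batch is assembled from the 3
        # rows of file 0 plus the first row of file 1, and the second batch
        # from the remaining 4 rows of file 1.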
250 | file_idx = 0 251 | row_idx = 0 252 | batch_idx = 0 253 | buffer_row_count = 0 254 | cur_batch_size = ( 255 | self.batch_size if self.num_full_batches > 0 else self.last_batch_sizes[0] 256 | ) 257 | while ( 258 | batch_idx 259 | < self.num_full_batches + (self.last_batch_sizes[0] > 0) * self.world_size 260 | ): 261 | if buffer_row_count == cur_batch_size or file_idx == len(self.dense_arrs): 262 | if batch_idx % self.world_size == self.rank: 263 | yield self._np_arrays_to_batch(*none_throws(buffer)) 264 | buffer = None 265 | buffer_row_count = 0 266 | batch_idx += 1 267 | if 0 <= batch_idx - self.num_full_batches < self.world_size and ( 268 | self.last_batch_sizes[0] > 0 269 | ): 270 | cur_batch_size = self.last_batch_sizes[ 271 | batch_idx - self.num_full_batches 272 | ] 273 | else: 274 | rows_to_get = min( 275 | cur_batch_size - buffer_row_count, 276 | self.num_rows_per_file[file_idx] - row_idx, 277 | ) 278 | buffer_row_count += rows_to_get 279 | slice_ = slice(row_idx, row_idx + rows_to_get) 280 | 281 | if batch_idx % self.world_size == self.rank: 282 | dense_inputs = self.dense_arrs[file_idx][slice_, :] 283 | sparse_inputs = [ 284 | feats[slice_, :] for feats in self.sparse_arrs[file_idx] 285 | ] 286 | target_labels = self.labels_arrs[file_idx][slice_, :] 287 | 288 | # if self.mmap_mode and self.hashes is not None: 289 | # sparse_inputs = [ 290 | # feats % hash 291 | # for (feats, hash) in zip(sparse_inputs, self.hashes) 292 | # ] 293 | 294 | append_to_buffer( 295 | dense_inputs, 296 | sparse_inputs, 297 | target_labels, 298 | ) 299 | row_idx += rows_to_get 300 | 301 | if row_idx >= self.num_rows_per_file[file_idx]: 302 | file_idx += 1 303 | row_idx = 0 304 | 305 | def __len__(self) -> int: 306 | return self.num_full_batches // self.world_size + (self.last_batch_sizes[0] > 0) 307 | -------------------------------------------------------------------------------- /torchrec_dlrm/dlrm_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import argparse 8 | import itertools 9 | import os 10 | import sys 11 | from dataclasses import dataclass, field 12 | from enum import Enum 13 | from typing import Iterator, List, Optional 14 | 15 | import torch 16 | import torchmetrics as metrics 17 | from pyre_extensions import none_throws 18 | from torch import distributed as dist 19 | from torch.utils.data import DataLoader 20 | from torchrec import EmbeddingBagCollection 21 | from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES 22 | from torchrec.distributed import TrainPipelineSparseDist 23 | from torchrec.distributed.comm import get_local_size 24 | from torchrec.distributed.model_parallel import ( 25 | DistributedModelParallel, 26 | get_default_sharders, 27 | ) 28 | from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology 29 | from torchrec.distributed.planner.storage_reservations import ( 30 | HeuristicalStorageReservation, 31 | ) 32 | from torchrec.models.dlrm import DLRM, DLRM_DCN, DLRM_Projection, DLRMTrain 33 | from torchrec.modules.embedding_configs import EmbeddingBagConfig 34 | from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward 35 | from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper 36 | from torchrec.optim.optimizers import in_backward_optimizer_filter 37 | from tqdm import tqdm 38 | 39 | # OSS import 40 | try: 41 | # pyre-ignore[21] 42 | # @manual=//ai_codesign/benchmarks/dlrm/torchrec_dlrm/data:dlrm_dataloader 43 | from data.dlrm_dataloader import get_dataloader 44 | 45 | # pyre-ignore[21] 46 | # @manual=//ai_codesign/benchmarks/dlrm/torchrec_dlrm:lr_scheduler 47 | from lr_scheduler import LRPolicyScheduler 48 | 49 | # pyre-ignore[21] 50 | # @manual=//ai_codesign/benchmarks/dlrm/torchrec_dlrm:multi_hot 51 | from multi_hot import Multihot, RestartableMap 52 | except ImportError: 53 | pass 54 | 55 | # internal import 56 | try: 57 | from .data.dlrm_dataloader import get_dataloader # noqa F811 58 | from .lr_scheduler import LRPolicyScheduler # noqa F811 59 | from .multi_hot import Multihot, RestartableMap # noqa F811 60 | except ImportError: 61 | pass 62 | 63 | TRAIN_PIPELINE_STAGES = 3 # Number of stages in TrainPipelineSparseDist. 
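# A sketch of the overlap this enables per training iteration t (a description
# based on TrainPipelineSparseDist's prefetching design; treat the exact stage
# assignment as an assumption, not a guarantee):
#   stage 1: host-to-device copy of the inputs for batch t + 2
#   stage 2: sparse-feature input redistribution (all-to-all) for batch t + 1
#   stage 3: forward/backward and optimizer step for batch t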
64 | 
65 | 
66 | class InteractionType(Enum):
67 |     ORIGINAL = "original"
68 |     DCN = "dcn"
69 |     PROJECTION = "projection"
70 | 
71 |     def __str__(self):
72 |         return self.value
73 | 
74 | 
75 | def parse_args(argv: List[str]) -> argparse.Namespace:
76 |     parser = argparse.ArgumentParser(description="torchrec dlrm example trainer")
77 |     parser.add_argument(
78 |         "--epochs",
79 |         type=int,
80 |         default=1,
81 |         help="number of epochs to train",
82 |     )
83 |     parser.add_argument(
84 |         "--batch_size",
85 |         type=int,
86 |         default=32,
87 |         help="batch size to use for training",
88 |     )
89 |     parser.add_argument(
90 |         "--drop_last_training_batch",
91 |         dest="drop_last_training_batch",
92 |         action="store_true",
93 |         help="Drop the last non-full training batch",
94 |     )
95 |     parser.add_argument(
96 |         "--test_batch_size",
97 |         type=int,
98 |         default=None,
99 |         help="batch size to use for validation and testing",
100 |     )
101 |     parser.add_argument(
102 |         "--limit_train_batches",
103 |         type=int,
104 |         default=None,
105 |         help="limit on the number of train batches",
106 |     )
107 |     parser.add_argument(
108 |         "--limit_val_batches",
109 |         type=int,
110 |         default=None,
111 |         help="limit on the number of validation batches",
112 |     )
113 |     parser.add_argument(
114 |         "--limit_test_batches",
115 |         type=int,
116 |         default=None,
117 |         help="limit on the number of test batches",
118 |     )
119 |     parser.add_argument(
120 |         "--dataset_name",
121 |         type=str,
122 |         choices=["criteo_1t", "criteo_kaggle"],
123 |         default="criteo_1t",
124 |         help="dataset for the experiment; criteo_1t and criteo_kaggle are currently supported",
125 |     )
126 |     parser.add_argument(
127 |         "--num_embeddings",
128 |         type=int,
129 |         default=100_000,
130 |         help="max_ind_size. The number of embeddings in each embedding table. Defaults"
131 |         " to 100_000 if num_embeddings_per_feature is not supplied.",
132 |     )
133 |     parser.add_argument(
134 |         "--num_embeddings_per_feature",
135 |         type=str,
136 |         default=None,
137 |         help="Comma separated max_ind_size per sparse feature. The number of embeddings"
138 |         " in each embedding table. 26 values are expected for the Criteo dataset.",
139 |     )
140 |     parser.add_argument(
141 |         "--dense_arch_layer_sizes",
142 |         type=str,
143 |         default="512,256,64",
144 |         help="Comma separated layer sizes for dense arch.",
145 |     )
146 |     parser.add_argument(
147 |         "--over_arch_layer_sizes",
148 |         type=str,
149 |         default="512,512,256,1",
150 |         help="Comma separated layer sizes for over arch.",
151 |     )
152 |     parser.add_argument(
153 |         "--embedding_dim",
154 |         type=int,
155 |         default=64,
156 |         help="Size of each embedding.",
157 |     )
158 |     parser.add_argument(
159 |         "--interaction_branch1_layer_sizes",
160 |         type=str,
161 |         default="2048,2048",
162 |         help="Comma separated layer sizes for interaction branch1 (only on dlrm with projection).",
163 |     )
164 |     parser.add_argument(
165 |         "--interaction_branch2_layer_sizes",
166 |         type=str,
167 |         default="2048,2048",
168 |         help="Comma separated layer sizes for interaction branch2 (only on dlrm with projection).",
169 |     )
170 |     parser.add_argument(
171 |         "--dcn_num_layers",
172 |         type=int,
173 |         default=3,
174 |         help="Number of DCN layers in interaction layer (only on dlrm with DCN).",
175 |     )
176 |     parser.add_argument(
177 |         "--dcn_low_rank_dim",
178 |         type=int,
179 |         default=512,
180 |         help="Low rank dimension for DCN in interaction layer (only on dlrm with DCN).",
181 |     )
182 |     parser.add_argument(
183 |         "--undersampling_rate",
184 |         type=float,
185 |         help="Desired proportion of zero-labeled samples to retain (i.e. undersampling zero-labeled rows)."
186 | " Ex. 0.3 indicates only 30pct of the rows with label 0 will be kept." 187 | " All rows with label 1 will be kept. Value should be between 0 and 1." 188 | " When not supplied, no undersampling occurs.", 189 | ) 190 | parser.add_argument( 191 | "--seed", 192 | type=int, 193 | help="Random seed for reproducibility.", 194 | ) 195 | parser.add_argument( 196 | "--pin_memory", 197 | dest="pin_memory", 198 | action="store_true", 199 | help="Use pinned memory when loading data.", 200 | ) 201 | parser.add_argument( 202 | "--mmap_mode", 203 | dest="mmap_mode", 204 | action="store_true", 205 | help="--mmap_mode mmaps the dataset." 206 | " That is, the dataset is kept on disk but is accessed as if it were in memory." 207 | " --mmap_mode is intended mostly for faster debugging. Use --mmap_mode to bypass" 208 | " preloading the dataset when preloading takes too long or when there is " 209 | " insufficient memory available to load the full dataset.", 210 | ) 211 | parser.add_argument( 212 | "--in_memory_binary_criteo_path", 213 | type=str, 214 | default=None, 215 | help="Directory path containing the Criteo dataset npy files.", 216 | ) 217 | parser.add_argument( 218 | "--synthetic_multi_hot_criteo_path", 219 | type=str, 220 | default=None, 221 | help="Directory path containing the MLPerf v2 synthetic multi-hot dataset npz files.", 222 | ) 223 | parser.add_argument( 224 | "--learning_rate", 225 | type=float, 226 | default=15.0, 227 | help="Learning rate.", 228 | ) 229 | parser.add_argument( 230 | "--eps", 231 | type=float, 232 | default=1e-8, 233 | help="Epsilon for Adagrad optimizer.", 234 | ) 235 | parser.add_argument( 236 | "--shuffle_batches", 237 | dest="shuffle_batches", 238 | action="store_true", 239 | help="Shuffle each batch during training.", 240 | ) 241 | parser.add_argument( 242 | "--shuffle_training_set", 243 | dest="shuffle_training_set", 244 | action="store_true", 245 | help="Shuffle the training set in memory. This will override mmap_mode", 246 | ) 247 | parser.add_argument( 248 | "--validation_freq_within_epoch", 249 | type=int, 250 | default=None, 251 | help="Frequency at which validation will be run within an epoch.", 252 | ) 253 | parser.set_defaults( 254 | pin_memory=None, 255 | mmap_mode=None, 256 | drop_last=None, 257 | shuffle_batches=None, 258 | shuffle_training_set=None, 259 | ) 260 | parser.add_argument( 261 | "--adagrad", 262 | dest="adagrad", 263 | action="store_true", 264 | help="Flag to determine if adagrad optimizer should be used.", 265 | ) 266 | parser.add_argument( 267 | "--interaction_type", 268 | type=InteractionType, 269 | choices=list(InteractionType), 270 | default=InteractionType.ORIGINAL, 271 | help="Determine the interaction type to be used (original, dcn, or projection)" 272 | " default is original DLRM with pairwise dot product", 273 | ) 274 | parser.add_argument( 275 | "--collect_multi_hot_freqs_stats", 276 | dest="collect_multi_hot_freqs_stats", 277 | action="store_true", 278 | help="Flag to determine whether to collect stats on freq of embedding access.", 279 | ) 280 | parser.add_argument( 281 | "--multi_hot_sizes", 282 | type=str, 283 | default=None, 284 | help="Comma separated multihot size per sparse feature. 
26 values are expected for the Criteo dataset.", 285 | ) 286 | parser.add_argument( 287 | "--multi_hot_distribution_type", 288 | type=str, 289 | choices=["uniform", "pareto"], 290 | default=None, 291 | help="Multi-hot distribution options.", 292 | ) 293 | parser.add_argument("--lr_warmup_steps", type=int, default=0) 294 | parser.add_argument("--lr_decay_start", type=int, default=0) 295 | parser.add_argument("--lr_decay_steps", type=int, default=0) 296 | parser.add_argument( 297 | "--print_lr", 298 | action="store_true", 299 | help="Print learning rate every iteration.", 300 | ) 301 | parser.add_argument( 302 | "--allow_tf32", 303 | action="store_true", 304 | help="Enable TensorFloat-32 mode for matrix multiplications on A100 (or newer) GPUs.", 305 | ) 306 | parser.add_argument( 307 | "--print_sharding_plan", 308 | action="store_true", 309 | help="Print the sharding plan used for each embedding table.", 310 | ) 311 | return parser.parse_args(argv) 312 | 313 | 314 | def _evaluate( 315 | limit_batches: Optional[int], 316 | pipeline: TrainPipelineSparseDist, 317 | eval_dataloader: DataLoader, 318 | stage: str, 319 | ) -> float: 320 | """ 321 | Evaluates model. Computes and prints AUROC. Helper function for train_val_test. 322 | 323 | Args: 324 | limit_batches (Optional[int]): Limits the dataloader to the first `limit_batches` batches. 325 | pipeline (TrainPipelineSparseDist): data pipeline. 326 | eval_dataloader (DataLoader): Dataloader for either the validation set or test set. 327 | stage (str): "val" or "test". 328 | 329 | Returns: 330 | float: auroc result 331 | """ 332 | pipeline._model.eval() 333 | device = pipeline._device 334 | 335 | iterator = itertools.islice(iter(eval_dataloader), limit_batches) 336 | 337 | auroc = metrics.AUROC(task="multiclass", num_classes=2).to(device) 338 | 339 | is_rank_zero = dist.get_rank() == 0 340 | if is_rank_zero: 341 | pbar = tqdm( 342 | iter(int, 1), 343 | desc=f"Evaluating {stage} set", 344 | total=len(eval_dataloader), 345 | disable=False, 346 | ) 347 | with torch.no_grad(): 348 | while True: 349 | try: 350 | _loss, logits, labels = pipeline.progress(iterator) 351 | preds = torch.sigmoid(logits) 352 | preds_reshaped = torch.stack((1 - preds, preds), dim=1) 353 | auroc(preds_reshaped, labels) 354 | if is_rank_zero: 355 | pbar.update(1) 356 | except StopIteration: 357 | break 358 | 359 | auroc_result = auroc.compute().item() 360 | num_samples = torch.tensor(sum(map(len, auroc.target)), device=device) 361 | dist.reduce(num_samples, 0, op=dist.ReduceOp.SUM) 362 | 363 | if is_rank_zero: 364 | print(f"AUROC over {stage} set: {auroc_result}.") 365 | print(f"Number of {stage} samples: {num_samples}") 366 | return auroc_result 367 | 368 | 369 | def batched(it: Iterator, n: int): 370 | assert n >= 1 371 | for x in it: 372 | yield itertools.chain((x,), itertools.islice(it, n - 1)) 373 | 374 | 375 | def _train( 376 | pipeline: TrainPipelineSparseDist, 377 | train_dataloader: DataLoader, 378 | val_dataloader: DataLoader, 379 | epoch: int, 380 | lr_scheduler, 381 | print_lr: bool, 382 | validation_freq: Optional[int], 383 | limit_train_batches: Optional[int], 384 | limit_val_batches: Optional[int], 385 | ) -> None: 386 | """ 387 | Trains model for 1 epoch. Helper function for train_val_test. 388 | 389 | Args: 390 | pipeline (TrainPipelineSparseDist): data pipeline. 391 | train_dataloader (DataLoader): Training set's dataloader. 392 | val_dataloader (DataLoader): Validation set's dataloader. 
393 | epoch (int): The number of complete passes through the training set so far. 394 | lr_scheduler (LRPolicyScheduler): Learning rate scheduler. 395 | print_lr (bool): Whether to print the learning rate every training step. 396 | validation_freq (Optional[int]): The number of training steps between validation runs within an epoch. 397 | limit_train_batches (Optional[int]): Limits the training set to the first `limit_train_batches` batches. 398 | limit_val_batches (Optional[int]): Limits the validation set to the first `limit_val_batches` batches. 399 | 400 | Returns: 401 | None. 402 | """ 403 | pipeline._model.train() 404 | 405 | iterator = itertools.islice(iter(train_dataloader), limit_train_batches) 406 | 407 | is_rank_zero = dist.get_rank() == 0 408 | if is_rank_zero: 409 | pbar = tqdm( 410 | iter(int, 1), 411 | desc=f"Epoch {epoch}", 412 | total=len(train_dataloader), 413 | disable=False, 414 | ) 415 | 416 | start_it = 0 417 | n = ( 418 | validation_freq 419 | if validation_freq 420 | else limit_train_batches 421 | if limit_train_batches 422 | else len(train_dataloader) 423 | ) 424 | for batched_iterator in batched(iterator, n): 425 | for it in itertools.count(start_it): 426 | try: 427 | if is_rank_zero and print_lr: 428 | for i, g in enumerate(pipeline._optimizer.param_groups): 429 | print(f"lr: {it} {i} {g['lr']:.6f}") 430 | pipeline.progress(batched_iterator) 431 | lr_scheduler.step() 432 | if is_rank_zero: 433 | pbar.update(1) 434 | except StopIteration: 435 | if is_rank_zero: 436 | print("Total number of iterations:", it) 437 | start_it = it 438 | break 439 | 440 | if validation_freq and start_it % validation_freq == 0: 441 | _evaluate(limit_val_batches, pipeline, val_dataloader, "val") 442 | pipeline._model.train() 443 | 444 | 445 | @dataclass 446 | class TrainValTestResults: 447 | val_aurocs: List[float] = field(default_factory=list) 448 | test_auroc: Optional[float] = None 449 | 450 | 451 | def train_val_test( 452 | args: argparse.Namespace, 453 | model: torch.nn.Module, 454 | optimizer: torch.optim.Optimizer, 455 | device: torch.device, 456 | train_dataloader: DataLoader, 457 | val_dataloader: DataLoader, 458 | test_dataloader: DataLoader, 459 | lr_scheduler: LRPolicyScheduler, 460 | ) -> TrainValTestResults: 461 | """ 462 | Train/validation/test loop. 463 | 464 | Args: 465 | args (argparse.Namespace): parsed command line args. 466 | model (torch.nn.Module): model to train. 467 | optimizer (torch.optim.Optimizer): optimizer to use. 468 | device (torch.device): device to use. 469 | train_dataloader (DataLoader): Training set's dataloader. 470 | val_dataloader (DataLoader): Validation set's dataloader. 471 | test_dataloader (DataLoader): Test set's dataloader. 472 | lr_scheduler (LRPolicyScheduler): Learning rate scheduler. 473 | 474 | Returns: 475 | TrainValTestResults. 
476 | """ 477 | results = TrainValTestResults() 478 | pipeline = TrainPipelineSparseDist( 479 | model, optimizer, device, execute_all_batches=True 480 | ) 481 | 482 | for epoch in range(args.epochs): 483 | _train( 484 | pipeline, 485 | train_dataloader, 486 | val_dataloader, 487 | epoch, 488 | lr_scheduler, 489 | args.print_lr, 490 | args.validation_freq_within_epoch, 491 | args.limit_train_batches, 492 | args.limit_val_batches, 493 | ) 494 | val_auroc = _evaluate(args.limit_val_batches, pipeline, val_dataloader, "val") 495 | results.val_aurocs.append(val_auroc) 496 | 497 | test_auroc = _evaluate(args.limit_test_batches, pipeline, test_dataloader, "test") 498 | results.test_auroc = test_auroc 499 | 500 | return results 501 | 502 | 503 | def main(argv: List[str]) -> None: 504 | """ 505 | Trains, validates, and tests a Deep Learning Recommendation Model (DLRM) 506 | (https://arxiv.org/abs/1906.00091). The DLRM model contains both data parallel 507 | components (e.g. multi-layer perceptrons & interaction arch) and model parallel 508 | components (e.g. embedding tables). The DLRM model is pipelined so that dataloading, 509 | data-parallel to model-parallel comms, and forward/backward are overlapped. Can be 510 | run with either a random dataloader or an in-memory Criteo 1 TB click logs dataset 511 | (https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/). 512 | 513 | Args: 514 | argv (List[str]): command line args. 515 | 516 | Returns: 517 | None. 518 | """ 519 | args = parse_args(argv) 520 | for name, val in vars(args).items(): 521 | try: 522 | vars(args)[name] = list(map(int, val.split(","))) 523 | except (ValueError, AttributeError): 524 | pass 525 | 526 | torch.backends.cuda.matmul.allow_tf32 = args.allow_tf32 527 | 528 | if args.multi_hot_sizes is not None: 529 | assert ( 530 | args.num_embeddings_per_feature is not None 531 | and len(args.multi_hot_sizes) == len(args.num_embeddings_per_feature) 532 | or args.num_embeddings_per_feature is None 533 | and len(args.multi_hot_sizes) == len(DEFAULT_CAT_NAMES) 534 | ), "--multi_hot_sizes must be a comma delimited list the same size as the number of embedding tables." 535 | assert ( 536 | args.in_memory_binary_criteo_path is None 537 | or args.synthetic_multi_hot_criteo_path is None 538 | ), "--in_memory_binary_criteo_path and --synthetic_multi_hot_criteo_path are mutually exclusive CLI arguments." 539 | assert ( 540 | args.multi_hot_sizes is None or args.synthetic_multi_hot_criteo_path is None 541 | ), "--multi_hot_sizes is used to convert 1-hot to multi-hot. It's inapplicable with --synthetic_multi_hot_criteo_path." 542 | assert ( 543 | args.multi_hot_distribution_type is None 544 | or args.synthetic_multi_hot_criteo_path is None 545 | ), "--multi_hot_distribution_type is used to convert 1-hot to multi-hot. It's inapplicable with --synthetic_multi_hot_criteo_path." 
546 | 547 | rank = int(os.environ["LOCAL_RANK"]) 548 | if torch.cuda.is_available(): 549 | device: torch.device = torch.device(f"cuda:{rank}") 550 | backend = "nccl" 551 | torch.cuda.set_device(device) 552 | else: 553 | device: torch.device = torch.device("cpu") 554 | backend = "gloo" 555 | 556 | if rank == 0: 557 | print( 558 | "PARAMS: (lr, batch_size, warmup_steps, decay_start, decay_steps): " 559 | f"{(args.learning_rate, args.batch_size, args.lr_warmup_steps, args.lr_decay_start, args.lr_decay_steps)}" 560 | ) 561 | dist.init_process_group(backend=backend) 562 | 563 | if args.num_embeddings_per_feature is not None: 564 | args.num_embeddings = None 565 | 566 | # Sets default limits for random dataloader iterations when left unspecified. 567 | if ( 568 | args.in_memory_binary_criteo_path 569 | is args.synthetic_multi_hot_criteo_path 570 | is None 571 | ): 572 | for split in ["train", "val", "test"]: 573 | attr = f"limit_{split}_batches" 574 | if getattr(args, attr) is None: 575 | setattr(args, attr, 10) 576 | 577 | train_dataloader = get_dataloader(args, backend, "train") 578 | val_dataloader = get_dataloader(args, backend, "val") 579 | test_dataloader = get_dataloader(args, backend, "test") 580 | 581 | eb_configs = [ 582 | EmbeddingBagConfig( 583 | name=f"t_{feature_name}", 584 | embedding_dim=args.embedding_dim, 585 | num_embeddings=( 586 | none_throws(args.num_embeddings_per_feature)[feature_idx] 587 | if args.num_embeddings is None 588 | else args.num_embeddings 589 | ), 590 | feature_names=[feature_name], 591 | ) 592 | for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES) 593 | ] 594 | sharded_module_kwargs = {} 595 | if args.over_arch_layer_sizes is not None: 596 | sharded_module_kwargs["over_arch_layer_sizes"] = args.over_arch_layer_sizes 597 | 598 | if args.interaction_type == InteractionType.ORIGINAL: 599 | dlrm_model = DLRM( 600 | embedding_bag_collection=EmbeddingBagCollection( 601 | tables=eb_configs, device=torch.device("meta") 602 | ), 603 | dense_in_features=len(DEFAULT_INT_NAMES), 604 | dense_arch_layer_sizes=args.dense_arch_layer_sizes, 605 | over_arch_layer_sizes=args.over_arch_layer_sizes, 606 | dense_device=device, 607 | ) 608 | elif args.interaction_type == InteractionType.DCN: 609 | dlrm_model = DLRM_DCN( 610 | embedding_bag_collection=EmbeddingBagCollection( 611 | tables=eb_configs, device=torch.device("meta") 612 | ), 613 | dense_in_features=len(DEFAULT_INT_NAMES), 614 | dense_arch_layer_sizes=args.dense_arch_layer_sizes, 615 | over_arch_layer_sizes=args.over_arch_layer_sizes, 616 | dcn_num_layers=args.dcn_num_layers, 617 | dcn_low_rank_dim=args.dcn_low_rank_dim, 618 | dense_device=device, 619 | ) 620 | elif args.interaction_type == InteractionType.PROJECTION: 621 | dlrm_model = DLRM_Projection( 622 | embedding_bag_collection=EmbeddingBagCollection( 623 | tables=eb_configs, device=torch.device("meta") 624 | ), 625 | dense_in_features=len(DEFAULT_INT_NAMES), 626 | dense_arch_layer_sizes=args.dense_arch_layer_sizes, 627 | over_arch_layer_sizes=args.over_arch_layer_sizes, 628 | interaction_branch1_layer_sizes=args.interaction_branch1_layer_sizes, 629 | interaction_branch2_layer_sizes=args.interaction_branch2_layer_sizes, 630 | dense_device=device, 631 | ) 632 | else: 633 | raise ValueError( 634 | "Unknown interaction option set. Should be original, dcn, or projection." 
635 |         )
636 | 
637 |     train_model = DLRMTrain(dlrm_model)
638 |     embedding_optimizer = torch.optim.Adagrad if args.adagrad else torch.optim.SGD
639 |     # This applies the chosen embedding optimizer (Adagrad or SGD) in the backward pass for the embeddings (sparse_arch). This means that
640 |     # the optimizer update is applied during the backward pass, in this case through a fused op.
641 |     # When Adagrad is selected, TorchRec uses the FBGEMM implementation of EXACT_ADAGRAD. For GPU devices, a fused CUDA kernel is invoked. For CPU, FBGEMM_GPU invokes CPU kernels.
642 |     # https://github.com/pytorch/FBGEMM/blob/2cb8b0dff3e67f9a009c4299defbd6b99cc12b8f/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py#L676-L678
643 | 
644 |     # Note that lr_decay, weight_decay and initial_accumulator_value for the Adagrad optimizer in FBGEMM v0.3.2
645 |     # cannot be specified below. In effect, all these parameters are hardcoded to zero.
646 |     optimizer_kwargs = {"lr": args.learning_rate}
647 |     if args.adagrad:
648 |         optimizer_kwargs["eps"] = args.eps
649 |     apply_optimizer_in_backward(
650 |         embedding_optimizer,
651 |         train_model.model.sparse_arch.parameters(),
652 |         optimizer_kwargs,
653 |     )
654 |     planner = EmbeddingShardingPlanner(
655 |         topology=Topology(
656 |             local_world_size=get_local_size(),
657 |             world_size=dist.get_world_size(),
658 |             compute_device=device.type,
659 |         ),
660 |         batch_size=args.batch_size,
661 |         # If you experience OOM, increase the percentage. See
662 |         # https://pytorch.org/torchrec/torchrec.distributed.planner.html#torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation
663 |         storage_reservation=HeuristicalStorageReservation(percentage=0.05),
664 |     )
665 |     plan = planner.collective_plan(
666 |         train_model, get_default_sharders(), dist.GroupMember.WORLD
667 |     )
668 | 
669 |     model = DistributedModelParallel(
670 |         module=train_model,
671 |         device=device,
672 |         plan=plan,
673 |     )
674 |     if rank == 0 and args.print_sharding_plan:
675 |         for collectionkey, plans in model._plan.plan.items():
676 |             print(collectionkey)
677 |             for table_name, plan in plans.items():
678 |                 print(table_name, "\n", plan, "\n")
679 | 
680 |     def optimizer_with_params():
681 |         if args.adagrad:
682 |             return lambda params: torch.optim.Adagrad(
683 |                 params, lr=args.learning_rate, eps=args.eps
684 |             )
685 |         else:
686 |             return lambda params: torch.optim.SGD(params, lr=args.learning_rate)
687 | 
688 |     dense_optimizer = KeyedOptimizerWrapper(
689 |         dict(in_backward_optimizer_filter(model.named_parameters())),
690 |         optimizer_with_params(),
691 |     )
692 |     optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
693 |     lr_scheduler = LRPolicyScheduler(
694 |         optimizer, args.lr_warmup_steps, args.lr_decay_start, args.lr_decay_steps
695 |     )
696 | 
697 |     if args.multi_hot_sizes is not None:
698 |         multihot = Multihot(
699 |             args.multi_hot_sizes,
700 |             args.num_embeddings_per_feature,
701 |             args.batch_size,
702 |             collect_freqs_stats=args.collect_multi_hot_freqs_stats,
703 |             dist_type=args.multi_hot_distribution_type,
704 |         )
705 |         multihot.pause_stats_collection_during_val_and_test(model)
706 |         train_dataloader = RestartableMap(
707 |             multihot.convert_to_multi_hot, train_dataloader
708 |         )
709 |         val_dataloader = RestartableMap(multihot.convert_to_multi_hot, val_dataloader)
710 |         test_dataloader = RestartableMap(multihot.convert_to_multi_hot, test_dataloader)
711 |     train_val_test(
712 |         args,
713 |         model,
714 |         optimizer,
715 |         device,
716 |         train_dataloader,
717 |         val_dataloader,
718 |         test_dataloader,
719 |         lr_scheduler,
720 |     )
721 |     if args.multi_hot_sizes is not None and args.collect_multi_hot_freqs_stats:  # multihot exists only when --multi_hot_sizes is set
722 |         multihot.save_freqs_stats()
723 | 
724 | 
725 | def invoke_main() -> None:
726 |     main(sys.argv[1:])
727 | 
728 | 
729 | if __name__ == "__main__":
730 |     invoke_main()  # pragma: no cover
731 | 
--------------------------------------------------------------------------------
/torchrec_dlrm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | # Copied from https://github.com/facebookresearch/dlrm/blob/mlperf/dlrm_s_pytorch.py
8 | 
9 | import sys
10 | 
11 | from torch.optim.lr_scheduler import _LRScheduler
12 | 
13 | 
14 | class LRPolicyScheduler(_LRScheduler):
15 |     def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps):
16 |         self.num_warmup_steps = num_warmup_steps
17 |         self.decay_start_step = decay_start_step
18 |         self.decay_end_step = decay_start_step + num_decay_steps
19 |         self.num_decay_steps = num_decay_steps
20 | 
21 |         if self.decay_start_step < self.num_warmup_steps:
22 |             sys.exit("Learning rate warmup must finish before the decay starts")
23 | 
24 |         super(LRPolicyScheduler, self).__init__(optimizer)
25 | 
26 |     def get_lr(self):
27 |         step_count = self._step_count
28 |         if step_count < self.num_warmup_steps:
29 |             # warmup
30 |             scale = 1.0 - (self.num_warmup_steps - step_count) / self.num_warmup_steps
31 |             lr = [base_lr * scale for base_lr in self.base_lrs]
32 |             self.last_lr = lr
33 |         elif self.decay_start_step <= step_count and step_count < self.decay_end_step:
34 |             # decay
35 |             decayed_steps = step_count - self.decay_start_step
36 |             scale = ((self.num_decay_steps - decayed_steps) / self.num_decay_steps) ** 2
37 |             min_lr = 0.0000001
38 |             lr = [max(min_lr, base_lr * scale) for base_lr in self.base_lrs]
39 |             self.last_lr = lr
40 |         else:
41 |             if self.num_decay_steps > 0:
42 |                 # freeze at last, either because we're after decay
43 |                 # or because we're between warmup and decay
44 |                 lr = self.last_lr
45 |             else:
46 |                 # do not adjust
47 |                 lr = self.base_lrs
48 |         return lr
49 | 
--------------------------------------------------------------------------------
/torchrec_dlrm/md5sums_MLPerf_v2_synthetic_multi_hot_sparse_dataset.txt:
--------------------------------------------------------------------------------
1 | 9287c283b01087427df915257300bf36 day_0_sparse_multi_hot.npz
2 | f4fedd921a4b214b03d0dfcd31cc30e6 day_1_sparse_multi_hot.npz
3 | 2a15bdd8d25781c4cdcf5d791dfd19f9 day_2_sparse_multi_hot.npz
4 | 0341aaee2f661f9e939a39d7a6be0aea day_3_sparse_multi_hot.npz
5 | e8db54dbf5fe438ecb76fcc2d520f31a day_4_sparse_multi_hot.npz
6 | fd35a7a2bc0ba63935b4b1c742eca018 day_5_sparse_multi_hot.npz
7 | 7d5c72b6bbe8be1dfa1f69db5c7e64fd day_6_sparse_multi_hot.npz
8 | 59bcb9855243d3a5c0ae56daa18e1033 day_7_sparse_multi_hot.npz
9 | b9f7fbccae6bb9fdabf30259b5e305f0 day_8_sparse_multi_hot.npz
10 | 03da5bf484870a3c77f66befc680ad04 day_9_sparse_multi_hot.npz
11 | eb048fc4fbd7ffa7932b81a523fe5a39 day_10_sparse_multi_hot.npz
12 | a2ebee45c9836c8c8598610a6a3e9d60 day_11_sparse_multi_hot.npz
13 | 0dd59855e1a7b65c42f7e7af303c610e day_12_sparse_multi_hot.npz
14 | 7510698f3fcff9f3d7ef9dd478e637aa day_13_sparse_multi_hot.npz
15 | 562978b7b93e179f1adb9b9f5e1dc338 day_14_sparse_multi_hot.npz
16 | 042967232f016fbccf0d40d72d0b48bb day_15_sparse_multi_hot.npz
17 | 7b59170fb0e2d1e78f15cb60cea22723
day_16_sparse_multi_hot.npz 18 | 5054c54515d2574cda0f11646787df44 day_17_sparse_multi_hot.npz 19 | 28d3dbf6c70e68f01df12a4c3298f754 day_18_sparse_multi_hot.npz 20 | db7554263a1754d3e29341d0f03bc2f0 day_19_sparse_multi_hot.npz 21 | 91ee92ffb4810c26e157c1335ef4de06 day_20_sparse_multi_hot.npz 22 | 2c99fad7b146b0ba581dce34f640f44e day_21_sparse_multi_hot.npz 23 | c7ba52c5aaf24d76acca22a0cb13b737 day_22_sparse_multi_hot.npz 24 | c46b7e31ec6f2f8768fa60bdfc0f6e40 day_23_sparse_multi_hot.npz 25 | -------------------------------------------------------------------------------- /torchrec_dlrm/md5sums_preprocessed_criteo_click_logs_dataset.txt: -------------------------------------------------------------------------------- 1 | 427113b0c4d85a8fceaf793457302067 day_0_dense.npy 2 | 4db255ce4388893e7aa1dcf157077975 day_0_labels.npy 3 | 8b444e74159dbede896e2f3b5ed31ac0 day_0_sparse.npy 4 | 3afc11c56062d8bbea4df300b5a42966 day_1_dense.npy 5 | fb40746738a7c6f4ee021033bdd518c5 day_1_labels.npy 6 | 61e95a487c955b515155b31611444f32 day_1_sparse.npy 7 | 4e73d5bb330c43826665bec142c6b407 day_2_dense.npy 8 | f0adfec8191781e3f201d45f923e6ea1 day_2_labels.npy 9 | 0473d30872cd6e582c5da0272a0569f8 day_2_sparse.npy 10 | df1f3395e0da4a06aa23b2e069ff3ad9 day_3_dense.npy 11 | 69caadf4d219f18b83f3591fe76f17c7 day_3_labels.npy 12 | d6b0d02ff18da470b7ee17f97d5380e0 day_3_sparse.npy 13 | 27868a93adc66c47d4246acbad8bb689 day_4_dense.npy 14 | c4a6a16342f0770d67d689c6c173c681 day_4_labels.npy 15 | ca54008489cb84becc3f37e7b29035c7 day_4_sparse.npy 16 | e9bc6de06d09b1feebf857d9786ee15c day_5_dense.npy 17 | 9e3e17f345474cfbde5d62b543e07d6b day_5_labels.npy 18 | d1374ee84f80ea147957f8af0e12ebe4 day_5_sparse.npy 19 | 09c8bf0fd4798172e0369134ddc7204a day_6_dense.npy 20 | 945cef1132ceab8b23f4d0e269522be2 day_6_labels.npy 21 | e4df1c271e1edd72ee4658a39cca2888 day_6_sparse.npy 22 | ae718f0d6d29a8b605ae5d12fad3ffcc day_7_dense.npy 23 | 5ff5e7eef5b88b80ef03d06fc7e81bcf day_7_labels.npy 24 | cbcb7501a6b74a45dd5c028c13a4afbc day_7_sparse.npy 25 | 5a589746fd15819afbc70e2503f94b35 day_8_dense.npy 26 | 43871397750dfdc69cadcbee7e95f2bd day_8_labels.npy 27 | c1fb4369c7da27d23f4c7f97c8893250 day_8_sparse.npy 28 | 4bb86eecb92eb4e3368085c2b1bab131 day_9_dense.npy 29 | f851934555147d436131230ebbdd5609 day_9_labels.npy 30 | e4ac0fb8a030f0769541f88142c9f931 day_9_sparse.npy 31 | 7fc29f50da6c60185381ca4ad1cb2059 day_10_dense.npy 32 | e3b3f6f974c4820064db0046bbf954c8 day_10_labels.npy 33 | 1018a9ab88c4a7369325c9d6df73b411 day_10_sparse.npy 34 | df822ae73cbaa016bf7d371d87313b56 day_11_dense.npy 35 | 26219e9c89c6ce831e7da273da666df1 day_11_labels.npy 36 | f1596fc0337443a6672a864cd541fb05 day_11_sparse.npy 37 | 015968b4d9940ec9e28cc34788013d6e day_12_dense.npy 38 | f0ca7ce0ab6033cdd355df94d11c7ed7 day_12_labels.npy 39 | 03a2ebd22b01cc18b6e338de77b4103f day_12_sparse.npy 40 | 9d79239a9e976e4dd9b8839c7cbe1eba day_13_dense.npy 41 | 4b099b9200bbb490afc08b5cd63daa0e day_13_labels.npy 42 | 2b507e0f97d972ea6ada9b3af64de151 day_13_sparse.npy 43 | 9242e6c974603ec235f163f72fdbc766 day_14_dense.npy 44 | 80cae15e032ffb9eff292738ba4d0dce day_14_labels.npy 45 | 3dccc979f7c71fae45a10c98ba6c9cb7 day_14_sparse.npy 46 | 64c6c0fcd0940f7e0d7001aa945ec8f8 day_15_dense.npy 47 | a6a730d1ef55368f3f0b21d32b039662 day_15_labels.npy 48 | c852516852cc404cb40d4de8626d2ca1 day_15_sparse.npy 49 | 5c75b60e63e9cf98dec13fbb64839c10 day_16_dense.npy 50 | 5a71a29d8df1e8baf6bf28353f1588d4 day_16_labels.npy 51 | 6c838050751697a91bbf3e68ffd4a696 day_16_sparse.npy 52 | 
9798bccb5a67c5eac834153ea8bbe110 day_17_dense.npy 53 | 0a814b7eb83f375dd5a555ade6908356 day_17_labels.npy 54 | 40d2bc23fbcccb3ddb1390cc5e694cf0 day_17_sparse.npy 55 | cda094dfe7f5711877a6486f9863cd4b day_18_dense.npy 56 | a4fa26ada0d4c312b7e3354de0f5ee30 day_18_labels.npy 57 | 51711de9194737813a74bfb25c0f5d30 day_18_sparse.npy 58 | 0f0b2c0ed279462cdcc6f79252fd3395 day_19_dense.npy 59 | b21ad457474b01bd3f95fc46b6b9f04b day_19_labels.npy 60 | dd4b72cd704981441d17687f526e42ae day_19_sparse.npy 61 | 95ffc084f6cafe382afe72cbcae186bc day_20_dense.npy 62 | 9555e572e8bee22d71db8c2ac121ea8a day_20_labels.npy 63 | bc9a8c79c93ea39f32230459b4c4572a day_20_sparse.npy 64 | 4680683973be5b1a890c9314cfb2e93b day_21_dense.npy 65 | 672edc866e7ff1928d15338a99e5f336 day_21_labels.npy 66 | e4a8ae42a6d46893da6edb73e7d8a3f7 day_21_sparse.npy 67 | 3d56f190639398da2bfdc33f87cd34f0 day_22_dense.npy 68 | 733da710c5981cb67d041aa1039e4e6b day_22_labels.npy 69 | 42ef88d6bb2550a88711fed6fc144846 day_22_sparse.npy 70 | cdf7af87cbc7e9b468c0be46b1767601 day_23_dense.npy 71 | dd68f93301812026ed6f58dfb0757fa7 day_23_labels.npy 72 | 0c33f1562529cc3bca7f3708e2be63c9 day_23_sparse.npy 73 | -------------------------------------------------------------------------------- /torchrec_dlrm/multi_hot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List, Tuple 7 | 8 | import numpy as np 9 | 10 | import torch 11 | from torchrec.datasets.utils import Batch 12 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor 13 | 14 | 15 | class RestartableMap: 16 | def __init__(self, f, source): 17 | self.source = source 18 | self.func = f 19 | 20 | def __iter__(self): 21 | for x in self.source: 22 | yield self.func(x) 23 | 24 | def __len__(self): 25 | return len(self.source) 26 | 27 | 28 | class Multihot: 29 | def __init__( 30 | self, 31 | multi_hot_sizes: List[int], 32 | num_embeddings_per_feature: List[int], 33 | batch_size: int, 34 | collect_freqs_stats: bool, 35 | dist_type: str = "uniform", 36 | ): 37 | if dist_type not in {"uniform", "pareto"}: 38 | raise ValueError( 39 | "Multi-hot distribution type {} is not supported." 40 | 'Only "uniform" and "pareto" are supported.'.format(dist_type) 41 | ) 42 | self.dist_type = dist_type 43 | self.multi_hot_sizes = multi_hot_sizes 44 | self.num_embeddings_per_feature = num_embeddings_per_feature 45 | self.batch_size = batch_size 46 | 47 | # Generate 1-hot to multi-hot lookup tables, one lookup table per sparse embedding table. 48 | self.multi_hot_tables_l = self.__make_multi_hot_indices_tables( 49 | dist_type, multi_hot_sizes, num_embeddings_per_feature 50 | ) 51 | 52 | # Pooling offsets are computed once and reused. 
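        # Worked example (hypothetical sizes): with multi_hot_sizes=[2, 3] and
        # batch_size=2, the per-sample lengths are [2, 2, 3, 3], so the
        # cumulative offsets tensor computed below is [0, 2, 4, 7, 10].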
53 |         self.offsets = self.__make_offsets(
54 |             multi_hot_sizes, num_embeddings_per_feature, batch_size
55 |         )
56 | 
57 |         # For plotting embedding-access frequency stats
58 |         self.collect_freqs_stats = collect_freqs_stats
59 |         self.model_to_track = None
60 |         self.freqs_pre_hash = []
61 |         self.freqs_post_hash = []
62 |         for embs_count in num_embeddings_per_feature:
63 |             self.freqs_pre_hash.append(np.zeros((embs_count)))
64 |             self.freqs_post_hash.append(np.zeros((embs_count)))
65 | 
66 |     def save_freqs_stats(self) -> None:
67 |         if torch.distributed.is_available() and torch.distributed.is_initialized():
68 |             rank = torch.distributed.get_rank()
69 |         else:
70 |             rank = 0
71 |         pre_dict = {str(k): e for k, e in enumerate(self.freqs_pre_hash)}
72 |         np.save(f"stats_pre_hash_{rank}_{self.dist_type}.npy", pre_dict)
73 |         post_dict = {str(k): e for k, e in enumerate(self.freqs_post_hash)}
74 |         np.save(f"stats_post_hash_{rank}_{self.dist_type}.npy", post_dict)
75 | 
76 |     def pause_stats_collection_during_val_and_test(
77 |         self, model: torch.nn.Module
78 |     ) -> None:
79 |         self.model_to_track = model
80 | 
81 |     def __make_multi_hot_indices_tables(
82 |         self,
83 |         dist_type: str,
84 |         multi_hot_sizes: List[int],
85 |         num_embeddings_per_feature: List[int],
86 |     ) -> List[torch.Tensor]:
87 |         np.random.seed(
88 |             0
89 |         )  # The seed is necessary for all ranks to produce the same lookup values.
90 |         multi_hot_tables_l = []
91 |         for embs_count, multi_hot_size in zip(
92 |             num_embeddings_per_feature, multi_hot_sizes
93 |         ):
94 |             embedding_ids = np.arange(embs_count)[:, np.newaxis]
95 |             if dist_type == "uniform":
96 |                 synthetic_sparse_ids = np.random.randint(
97 |                     0, embs_count, size=(embs_count, multi_hot_size - 1)
98 |                 )
99 |             elif dist_type == "pareto":
100 |                 synthetic_sparse_ids = (
101 |                     np.random.pareto(
102 |                         a=0.25, size=(embs_count, multi_hot_size - 1)
103 |                     ).astype(np.int32)
104 |                     % embs_count
105 |                 )
106 |             multi_hot_table = np.concatenate(
107 |                 (embedding_ids, synthetic_sparse_ids), axis=-1
108 |             )
109 |             multi_hot_tables_l.append(multi_hot_table)
110 |         multi_hot_tables_l = [
111 |             torch.from_numpy(multi_hot_table).int()
112 |             for multi_hot_table in multi_hot_tables_l
113 |         ]
114 |         return multi_hot_tables_l
115 | 
116 |     def __make_offsets(
117 |         self,
118 |         multi_hot_sizes: List[int],
119 |         num_embeddings_per_feature: List[int],
120 |         batch_size: int,
121 |     ) -> torch.Tensor:
122 |         lS_o = torch.ones(
123 |             (len(num_embeddings_per_feature) * batch_size), dtype=torch.int32
124 |         )
125 |         for k, multi_hot_size in enumerate(multi_hot_sizes):
126 |             lS_o[k * batch_size : (k + 1) * batch_size] = multi_hot_size
127 |         lS_o = torch.cumsum(torch.concat((torch.tensor([0]), lS_o)), dim=0)
128 |         return lS_o
129 | 
130 |     def __make_new_batch(
131 |         self,
132 |         lS_i: torch.Tensor,
133 |         batch_size: int,
134 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
135 |         lS_i = lS_i.reshape(-1, batch_size)
136 |         multi_hot_ids_l = []
137 |         for k, (sparse_data_batch_for_table, multi_hot_table) in enumerate(
138 |             zip(lS_i, self.multi_hot_tables_l)
139 |         ):
140 |             multi_hot_ids = torch.nn.functional.embedding(
141 |                 sparse_data_batch_for_table, multi_hot_table
142 |             )
143 |             multi_hot_ids = multi_hot_ids.reshape(-1)
144 |             multi_hot_ids_l.append(multi_hot_ids)
145 |             if self.collect_freqs_stats and (
146 |                 self.model_to_track is None or self.model_to_track.training
147 |             ):
148 |                 idx_pre, cnt_pre = np.unique(
149 |                     sparse_data_batch_for_table, return_counts=True
150 |                 )
151 |                 idx_post, cnt_post = np.unique(multi_hot_ids, return_counts=True)
152 | 
self.freqs_pre_hash[k][idx_pre] += cnt_pre 153 | self.freqs_post_hash[k][idx_post] += cnt_post 154 | lS_i = torch.cat(multi_hot_ids_l) 155 | if batch_size == self.batch_size: 156 | return lS_i, self.offsets 157 | else: 158 | return lS_i, self.__make_offsets( 159 | self.multi_hot_sizes, self.num_embeddings_per_feature, batch_size 160 | ) 161 | 162 | def convert_to_multi_hot(self, batch: Batch) -> Batch: 163 | batch_size = len(batch.dense_features) 164 | lS_i = batch.sparse_features._values 165 | lS_i, lS_o = self.__make_new_batch(lS_i, batch_size) 166 | new_sparse_features = KeyedJaggedTensor.from_offsets_sync( 167 | keys=batch.sparse_features._keys, 168 | values=lS_i, 169 | offsets=lS_o, 170 | ) 171 | return Batch( 172 | dense_features=batch.dense_features, 173 | sparse_features=new_sparse_features, 174 | labels=batch.labels, 175 | ) 176 | -------------------------------------------------------------------------------- /torchrec_dlrm/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | torchmetrics 3 | -------------------------------------------------------------------------------- /torchrec_dlrm/scripts/download_Criteo_1TB_Click_Logs_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | base_url="https://storage.googleapis.com/criteo-cail-datasets/day_" 8 | for i in {0..23}; do 9 | url="$base_url$i.gz" 10 | echo Downloading "$url" 11 | wget "$url" 12 | done 13 | -------------------------------------------------------------------------------- /torchrec_dlrm/scripts/materialize_synthetic_multihot_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import os 9 | import pathlib 10 | import shutil 11 | import sys 12 | 13 | import numpy as np 14 | import torch 15 | from torch import distributed as dist, nn 16 | from torchrec.datasets.criteo import DAYS 17 | 18 | p = pathlib.Path(__file__).absolute().parents[1].resolve() 19 | sys.path.append(os.fspath(p)) 20 | 21 | # OSS import 22 | try: 23 | # pyre-ignore[21] 24 | # @manual=//ai_codesign/benchmarks/dlrm/torchrec_dlrm:multi_hot 25 | from multi_hot import Multihot 26 | except ImportError: 27 | pass 28 | 29 | # internal import 30 | try: 31 | from .multi_hot import Multihot # noqa F811 32 | except ImportError: 33 | pass 34 | 35 | 36 | def parse_args() -> argparse.Namespace: 37 | parser = argparse.ArgumentParser( 38 | description="Script to materialize synthetic multi-hot dataset into NumPy npz file format." 39 | ) 40 | parser.add_argument( 41 | "--in_memory_binary_criteo_path", 42 | type=str, 43 | required=True, 44 | help="Path to a folder containing the binary (npy) files for the Criteo dataset." 
45 | " When supplied, InMemoryBinaryCriteoIterDataPipe is used.", 46 | ) 47 | parser.add_argument( 48 | "--output_path", 49 | type=str, 50 | required=True, 51 | help="Path to outputted multi-hot sparse dataset.", 52 | ) 53 | parser.add_argument( 54 | "--copy_labels_and_dense", 55 | dest="copy_labels_and_dense", 56 | action="store_true", 57 | help="Flag to determine whether to copy labels and dense data to the output directory.", 58 | ) 59 | parser.add_argument( 60 | "--num_embeddings_per_feature", 61 | type=str, 62 | required=True, 63 | help="Comma separated max_ind_size per sparse feature. The number of embeddings" 64 | " in each embedding table. 26 values are expected for the Criteo dataset.", 65 | ) 66 | parser.add_argument( 67 | "--multi_hot_sizes", 68 | type=str, 69 | required=True, 70 | help="Comma separated multihot size per sparse feature. 26 values are expected for the Criteo dataset.", 71 | ) 72 | parser.add_argument( 73 | "--multi_hot_distribution_type", 74 | type=str, 75 | choices=["uniform", "pareto"], 76 | default="uniform", 77 | help="Multi-hot distribution options.", 78 | ) 79 | return parser.parse_args() 80 | 81 | 82 | def main() -> None: 83 | """ 84 | This script generates and saves the MLPerf v2 multi-hot dataset (4 TB in size). 85 | First, run process_Criteo_1TB_Click_Logs_dataset.sh. 86 | Then, run this script as follows: 87 | 88 | python materialize_synthetic_multihot_dataset.py \ 89 | --in_memory_binary_criteo_path $PREPROCESSED_CRITEO_1TB_CLICK_LOGS_DATASET_PATH \ 90 | --output_path $MATERIALIZED_DATASET_PATH \ 91 | --num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 \ 92 | --multi_hot_sizes 3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1 \ 93 | --multi_hot_distribution_type uniform 94 | 95 | This script takes about 2 hours to run (can be parallelized if needed). 
96 | """ 97 | args = parse_args() 98 | for name, val in vars(args).items(): 99 | try: 100 | vars(args)[name] = list(map(int, val.split(","))) 101 | except (ValueError, AttributeError): 102 | pass 103 | try: 104 | backend = "nccl" if torch.cuda.is_available() else "gloo" 105 | if not dist.is_initialized(): 106 | dist.init_process_group(backend=backend) 107 | rank = dist.get_rank() 108 | world_size = dist.get_world_size() 109 | except (KeyError, ValueError): 110 | rank = 0 111 | world_size = 1 112 | 113 | print("Generating one-hot to multi-hot lookup table.") 114 | multihot = Multihot( 115 | multi_hot_sizes=args.multi_hot_sizes, 116 | num_embeddings_per_feature=args.num_embeddings_per_feature, 117 | batch_size=1, # Doesn't matter 118 | collect_freqs_stats=False, 119 | dist_type=args.multi_hot_distribution_type, 120 | ) 121 | 122 | os.makedirs(args.output_path, exist_ok=True) 123 | 124 | for i in range(rank, DAYS, world_size): 125 | input_file_path = os.path.join( 126 | args.in_memory_binary_criteo_path, f"day_{i}_sparse.npy" 127 | ) 128 | print(f"Materializing {input_file_path}") 129 | sparse_data = np.load(input_file_path, mmap_mode="r") 130 | multi_hot_ids_dict = {} 131 | for j, (multi_hot_table, hash) in enumerate( 132 | zip(multihot.multi_hot_tables_l, args.num_embeddings_per_feature) 133 | ): 134 | sparse_tensor = torch.from_numpy(sparse_data[:, j] % hash) 135 | multi_hot_ids_dict[str(j)] = nn.functional.embedding( 136 | sparse_tensor, multi_hot_table 137 | ).numpy() 138 | output_file_path = os.path.join( 139 | args.output_path, f"day_{i}_sparse_multi_hot.npz" 140 | ) 141 | np.savez(output_file_path, **multi_hot_ids_dict) 142 | if args.copy_labels_and_dense: 143 | for part in ["labels", "dense"]: 144 | source_path = os.path.join( 145 | args.in_memory_binary_criteo_path, f"day_{i}_{part}.npy" 146 | ) 147 | output_path = os.path.join(args.output_path, f"day_{i}_{part}.npy") 148 | shutil.copyfile(source_path, output_path) 149 | print(f"Copying {source_path} to {output_path}") 150 | 151 | 152 | if __name__ == "__main__": 153 | main() 154 | -------------------------------------------------------------------------------- /torchrec_dlrm/scripts/process_Criteo_1TB_Click_Logs_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | display_help() { 8 | echo "Three command line arguments are required." 9 | echo "Example usage:" 10 | echo "bash process_Criteo_1TB_Click_Logs_dataset.sh \\" 11 | echo "./criteo_1tb/raw_input_dataset_dir \\" 12 | echo "./criteo_1tb/temp_intermediate_files_dir \\" 13 | echo "./criteo_1tb/numpy_contiguous_shuffled_output_dataset_dir" 14 | exit 1 15 | } 16 | 17 | [ -z "$1" ] && display_help 18 | [ -z "$2" ] && display_help 19 | [ -z "$3" ] && display_help 20 | 21 | # Input directory containing the raw Criteo 1TB Click Logs dataset files in tsv format. 22 | # The 24 dataset filenames in the directory should be day_{0..23} with no .tsv extension. 23 | raw_tsv_criteo_files_dir=$(readlink -m "$1") 24 | 25 | # Directory to store temporary intermediate output files created by preprocessing steps 1 and 2. 26 | temp_files_dir=$(readlink -m "$2") 27 | 28 | # Directory to store temporary intermediate output files created by preprocessing step 1. 
29 | step_1_output_dir="$temp_files_dir/temp_output_of_step_1" 30 | 31 | # Directory to store temporary intermediate output files created by preprocessing step 2. 32 | step_2_output_dir="$temp_files_dir/temp_output_of_step_2" 33 | 34 | # Directory to store the final preprocessed Criteo 1TB Click Logs dataset. 35 | step_3_output_dir=$(readlink -m "$3") 36 | 37 | # Step 1. Split the dataset into 3 sets of 24 numpy files: 38 | # day_{0..23}_dense.npy, day_{0..23}_labels.npy, and day_{0..23}_sparse.npy (~24hrs) 39 | set -x 40 | mkdir -p "$step_1_output_dir" 41 | date 42 | python -m torchrec.datasets.scripts.npy_preproc_criteo --input_dir "$raw_tsv_criteo_files_dir" --output_dir "$step_1_output_dir" || exit 43 | 44 | # Step 2. Convert all sparse indices in day_{0..23}_sparse.npy to contiguous indices and save the output. 45 | # The output filenames are day_{0..23}_sparse_contig_freq.npy 46 | mkdir -p "$step_2_output_dir" 47 | date 48 | python -m torchrec.datasets.scripts.contiguous_preproc_criteo --input_dir "$step_1_output_dir" --output_dir "$step_2_output_dir" --frequency_threshold 0 || exit 49 | 50 | date 51 | for i in {0..23} 52 | do 53 | name="$step_2_output_dir/day_$i""_sparse_contig_freq.npy" 54 | renamed="$step_2_output_dir/day_$i""_sparse.npy" 55 | echo "Renaming $name to $renamed" 56 | mv "$name" "$renamed" 57 | done 58 | 59 | # Step 3. Shuffle the dataset's samples in days 0 through 22. (~20hrs) 60 | # Day 23's samples are not shuffled and will be used for the validation set and test set. 61 | mkdir -p "$step_3_output_dir" 62 | date 63 | python -m torchrec.datasets.scripts.shuffle_preproc_criteo --input_dir_labels_and_dense "$step_1_output_dir" --input_dir_sparse "$step_2_output_dir" --output_dir_shuffled "$step_3_output_dir" --random_seed 0 || exit 64 | date 65 | -------------------------------------------------------------------------------- /torchrec_dlrm/tests/test_dlrm_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import tempfile 9 | import unittest 10 | import uuid 11 | 12 | from torch.distributed.launcher.api import elastic_launch, LaunchConfig 13 | from torchrec import test_utils 14 | from torchrec.datasets.test_utils.criteo_test_utils import CriteoTest 15 | 16 | from ..dlrm_main import main 17 | 18 | 19 | class MainTest(unittest.TestCase): 20 | @classmethod 21 | def _run_trainer_random(cls) -> None: 22 | main( 23 | [ 24 | "--limit_train_batches", 25 | "10", 26 | "--limit_val_batches", 27 | "8", 28 | "--limit_test_batches", 29 | "6", 30 | "--over_arch_layer_sizes", 31 | "8,1", 32 | "--dense_arch_layer_sizes", 33 | "8,8", 34 | "--embedding_dim", 35 | "8", 36 | "--num_embeddings", 37 | "8", 38 | ] 39 | ) 40 | 41 | @test_utils.skip_if_asan 42 | def test_main_function(self) -> None: 43 | with tempfile.TemporaryDirectory() as tmpdir: 44 | lc = LaunchConfig( 45 | min_nodes=1, 46 | max_nodes=1, 47 | nproc_per_node=2, 48 | run_id=str(uuid.uuid4()), 49 | rdzv_backend="c10d", 50 | rdzv_endpoint=os.path.join(tmpdir, "rdzv"), 51 | rdzv_configs={"store_type": "file"}, 52 | start_method="spawn", 53 | monitor_interval=1, 54 | max_restarts=0, 55 | ) 56 | 57 | elastic_launch(config=lc, entrypoint=self._run_trainer_random)() 58 | 59 | @classmethod 60 | def _run_trainer_criteo_in_memory(cls) -> None: 61 | with CriteoTest._create_dataset_npys( 62 | num_rows=50, filenames=[f"day_{i}" for i in range(24)] 63 | ) as files: 64 | main( 65 | [ 66 | "--over_arch_layer_sizes", 67 | "8,1", 68 | "--dense_arch_layer_sizes", 69 | "8,8", 70 | "--embedding_dim", 71 | "8", 72 | "--num_embeddings", 73 | "64", 74 | "--batch_size", 75 | "2", 76 | "--in_memory_binary_criteo_path", 77 | os.path.dirname(files[0]), 78 | "--epochs", 79 | "2", 80 | ] 81 | ) 82 | 83 | @test_utils.skip_if_asan 84 | def test_main_function_criteo_in_memory(self) -> None: 85 | with tempfile.TemporaryDirectory() as tmpdir: 86 | lc = LaunchConfig( 87 | min_nodes=1, 88 | max_nodes=1, 89 | nproc_per_node=2, 90 | run_id=str(uuid.uuid4()), 91 | rdzv_backend="c10d", 92 | rdzv_endpoint=os.path.join(tmpdir, "rdzv"), 93 | rdzv_configs={"store_type": "file"}, 94 | start_method="spawn", 95 | monitor_interval=1, 96 | max_restarts=0, 97 | ) 98 | 99 | elastic_launch(config=lc, entrypoint=self._run_trainer_criteo_in_memory)() 100 | 101 | @classmethod 102 | def _run_trainer_dcn(cls) -> None: 103 | with CriteoTest._create_dataset_npys( 104 | num_rows=50, filenames=[f"day_{i}" for i in range(24)] 105 | ) as files: 106 | main( 107 | [ 108 | "--over_arch_layer_sizes", 109 | "8,1", 110 | "--dense_arch_layer_sizes", 111 | "8,8", 112 | "--embedding_dim", 113 | "8", 114 | "--num_embeddings", 115 | "64", 116 | "--batch_size", 117 | "2", 118 | "--in_memory_binary_criteo_path", 119 | os.path.dirname(files[0]), 120 | "--epochs", 121 | "2", 122 | "--interaction_type", 123 | "dcn", 124 | "--dcn_num_layers", 125 | "2", 126 | "--dcn_low_rank_dim", 127 | "8", 128 | ] 129 | ) 130 | 131 | @test_utils.skip_if_asan 132 | def test_main_function_dcn(self) -> None: 133 | with tempfile.TemporaryDirectory() as tmpdir: 134 | lc = LaunchConfig( 135 | min_nodes=1, 136 | max_nodes=1, 137 | nproc_per_node=2, 138 | run_id=str(uuid.uuid4()), 139 | rdzv_backend="c10d", 140 | rdzv_endpoint=os.path.join(tmpdir, "rdzv"), 141 | rdzv_configs={"store_type": "file"}, 142 | start_method="spawn", 143 | monitor_interval=1, 144 | max_restarts=0, 145 | ) 146 | 147 | elastic_launch(config=lc, entrypoint=self._run_trainer_dcn)() 148 | 
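# A minimal way to run these tests locally (an assumption: the torchrec test
# utilities and a working c10d rendezvous are available; each test spawns 2
# processes via elastic_launch):
#   python -m unittest torchrec_dlrm.tests.test_dlrm_main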
--------------------------------------------------------------------------------
/tricks/md_embedding_bag.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | 
7 | # Mixed-Dimensions Trick
8 | #
9 | # Description: Applies the mixed-dimension trick to embeddings to reduce
10 | # embedding sizes.
11 | #
12 | # References:
13 | # [1] Antonio Ginart, Maxim Naumov, Dheevatsa Mudigere, Jiyan Yang, James Zou,
14 | # "Mixed Dimension Embeddings with Application to Memory-Efficient Recommendation
15 | # Systems", CoRR, arXiv:1909.11810, 2019
16 | from __future__ import absolute_import, division, print_function, unicode_literals
17 | 
18 | import torch
19 | import torch.nn as nn
20 | 
21 | 
22 | def md_solver(n, alpha, d0=None, B=None, round_dim=True, k=None):
23 |     """
24 |     External-facing function for mixed-dimension assignment using the
25 |     alpha-power temperature heuristic. Exactly one of d0 and B should be
26 |     given; if both are, d0 takes precedence.
27 |     Inputs:
28 |     n -- (torch.LongTensor); vector of the number of rows of each embedding matrix
29 |     alpha -- (torch.FloatTensor); scalar, non-negative, controls dimension skew
30 |     d0 -- (torch.FloatTensor); scalar, baseline embedding dimension
31 |     B -- (torch.FloatTensor); scalar, parameter budget for the embedding layer
32 |     round_dim -- (bool); whether to round dimensions to the nearest power of 2
33 |     k -- (torch.LongTensor); vector of the average number of queries per inference
34 |     """
35 |     n, indices = torch.sort(n)
36 |     k = k[indices] if k is not None else torch.ones(len(n))
37 |     d = alpha_power_rule(n.type(torch.float) / k, alpha, d0=d0, B=B)
38 |     if round_dim:
39 |         d = pow_2_round(d)
40 |     # Restore the caller's original table order.
41 |     undo_sort = [0] * len(indices)
42 |     for i, v in enumerate(indices):
43 |         undo_sort[v] = i
44 |     return d[undo_sort]
45 | 
46 | 
47 | def alpha_power_rule(n, alpha, d0=None, B=None):
48 |     if d0 is not None:
49 |         lamb = d0 * (n[0].type(torch.float) ** alpha)
50 |     elif B is not None:
51 |         lamb = B / torch.sum(n.type(torch.float) ** (1 - alpha))
52 |     else:
53 |         raise ValueError("Must specify either d0 or B")
54 |     d = torch.ones(len(n)) * lamb * (n.type(torch.float) ** (-alpha))
55 |     # Pin the smallest table to d0 when given; clamp all dims to at least 1.
56 |     for i in range(len(d)):
57 |         if i == 0 and d0 is not None:
58 |             d[i] = d0
59 |         else:
60 |             d[i] = 1 if d[i] < 1 else d[i]
61 |     return torch.round(d).type(torch.long)
62 | 
63 | 
64 | def pow_2_round(dims):
65 |     # Cast back to long so callers always receive integer dimensions.
66 |     return (2 ** torch.round(torch.log2(dims.type(torch.float)))).type(torch.long)
67 | 
68 | 
69 | class PrEmbeddingBag(nn.Module):
70 |     """An EmbeddingBag at a (possibly reduced) dimension, projected up to base_dim."""
71 | 
72 |     def __init__(self, num_embeddings, embedding_dim, base_dim):
73 |         super(PrEmbeddingBag, self).__init__()
74 |         self.embs = nn.EmbeddingBag(
75 |             num_embeddings, embedding_dim, mode="sum", sparse=True
76 |         )
77 |         torch.nn.init.xavier_uniform_(self.embs.weight)
78 |         if embedding_dim < base_dim:
79 |             self.proj = nn.Linear(embedding_dim, base_dim, bias=False)
80 |             torch.nn.init.xavier_uniform_(self.proj.weight)
81 |         elif embedding_dim == base_dim:
82 |             self.proj = nn.Identity()
83 |         else:
84 |             raise ValueError(
85 |                 f"Embedding dim {embedding_dim} > base dim {base_dim}"
86 |             )
87 | 
88 |     def forward(self, input, offsets=None, per_sample_weights=None):
89 |         return self.proj(
90 |             self.embs(input, offsets=offsets, per_sample_weights=per_sample_weights)
91 |         )
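92 | 
93 | 
94 | if __name__ == "__main__":
95 |     # Minimal usage sketch (not part of the original module): assign mixed
96 |     # dimensions to four hypothetical tables against a 32-dim baseline, then
97 |     # build a projected bag for the largest table.
98 |     n = torch.tensor([1000, 100, 50, 10])
99 |     dims = md_solver(n, alpha=0.3, d0=32)
100 |     print(dims)  # per-table dims, rounded to powers of 2
101 |     bag = PrEmbeddingBag(1000, int(dims[0].item()), base_dim=32)
102 |     out = bag(torch.tensor([1, 2, 3, 4]), offsets=torch.tensor([0, 2]))
103 |     print(out.shape)  # torch.Size([2, 32])
--------------------------------------------------------------------------------
/tricks/qr_embedding_bag.py: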
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | # Quotient-Remainder Trick
7 | #
8 | # Description: Applies the quotient-remainder trick to embeddings to reduce
9 | # embedding sizes.
10 | #
11 | # References:
12 | # [1] Hao-Jun Michael Shi, Dheevatsa Mudigere, Maxim Naumov, Jiyan Yang,
13 | # "Compositional Embeddings Using Complementary Partitions for Memory-Efficient
14 | # Recommendation Systems", CoRR, arXiv:1909.02107, 2019
15 | 
16 | 
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 | 
19 | import numpy as np
20 | import torch
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 | from torch.nn.parameter import Parameter
24 | 
25 | 
26 | class QREmbeddingBag(nn.Module):
27 |     r"""Computes sums or means over two 'bags' of embeddings, one indexed by the
28 |     quotient of the raw indices and the other by their remainder, without
29 |     instantiating the full embedding table, then combines the two results.
30 | 
31 |     For bags of constant length and no :attr:`per_sample_weights`, this class
32 | 
33 |     * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=0)``,
34 |     * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=0)``,
35 |     * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=0)``.
36 | 
37 |     However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
38 |     operations.
39 | 
40 |     QREmbeddingBag also supports per-sample weights as an argument to the forward
41 |     pass. This scales the output of the Embedding before performing a weighted
42 |     reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
43 |     only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
44 |     :attr:`per_sample_weights`.
45 | 
46 |     Known Issues:
47 |     Autograd breaks with multiple GPUs, and only when multiple embeddings are used.
48 | 
49 |     Args:
50 |         num_categories (int): total number of unique categories. The input indices must be in
51 |             0, 1, ..., num_categories - 1.
52 |         embedding_dim (list or int): list of sizes for each embedding vector in each table. If the ``"add"``
53 |             or ``"mult"`` operation is used, these embedding dimensions must be
54 |             the same. If a single embedding_dim is given, it is used for both
55 |             embedding tables.
56 |         num_collisions (int): number of collisions to enforce.
57 |         operation (string, optional): ``"concat"``, ``"add"``, or ``"mult"``. Specifies the operation
58 |             to compose embeddings. ``"concat"`` concatenates the embeddings,
59 |             ``"add"`` sums the embeddings, and ``"mult"`` multiplies
60 |             (component-wise) the embeddings.
61 |             Default: ``"mult"``
62 |         max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
63 |             is renormalized to have norm :attr:`max_norm`.
64 |         norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
65 |         scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the inverse of frequency of
66 |             the words in the mini-batch. Default ``False``.
67 |             Note: this option is not supported when ``mode="max"``.
68 |         mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
69 |             ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
70 |             into consideration. ``"mean"`` computes the average of the values
71 |             in the bag, ``"max"`` computes the max value over each bag.
72 |             Default: ``"mean"``
73 |         sparse (bool, optional): if ``True``, gradients w.r.t. the weight matrices will be sparse tensors. See
74 |             Notes for more details regarding sparse gradients. Note: this option is not
75 |             supported when ``mode="max"``.
76 | 
77 |     Attributes:
78 |         weight_q, weight_r (Tensor): the learnable weights of the quotient and remainder tables, of shapes
79 |             `(ceil(num_categories / num_collisions), embedding_dim[0])` and
80 |             `(num_collisions, embedding_dim[1])`, both initialized from a uniform
81 |             distribution bounded by sqrt(1 / num_categories).
82 | 
83 |     Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and
84 |         :attr:`per_sample_weights` (Tensor, optional)
85 | 
86 |         - If :attr:`input` is 2D of shape `(B, N)`,
87 | 
88 |           it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and
89 |           this will return ``B`` values aggregated in a way depending on the :attr:`mode`.
90 |           :attr:`offsets` is ignored and required to be ``None`` in this case.
91 | 
92 |         - If :attr:`input` is 1D of shape `(N)`,
93 | 
94 |           it will be treated as a concatenation of multiple bags (sequences).
95 |           :attr:`offsets` is required to be a 1D tensor containing the
96 |           starting index positions of each bag in :attr:`input`. Therefore,
97 |           for :attr:`offsets` of shape `(B)`, :attr:`input` will be viewed as
98 |           having ``B`` bags. Empty bags (i.e., having 0-length) will have
99 |           returned vectors filled by zeros.
100 | 
101 |         per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
102 |             to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
103 |             must have exactly the same shape as input and is treated as having the same
104 |             :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.
105 | 
106 |     Output shape: `(B, embedding_dim)`, or `(B, embedding_dim[0] + embedding_dim[1])`
107 |         when ``operation="concat"``.
108 | 
109 |     """
110 | 
111 |     __constants__ = [
112 |         "num_categories",
113 |         "embedding_dim",
114 |         "num_collisions",
115 |         "operation",
116 |         "max_norm",
117 |         "norm_type",
118 |         "scale_grad_by_freq",
119 |         "mode",
120 |         "sparse",
121 |     ]
122 | 
123 |     def __init__(
124 |         self,
125 |         num_categories,
126 |         embedding_dim,
127 |         num_collisions,
128 |         operation="mult",
129 |         max_norm=None,
130 |         norm_type=2.0,
131 |         scale_grad_by_freq=False,
132 |         mode="mean",
133 |         sparse=False,
134 |         _weight=None,
135 |     ):
136 |         super(QREmbeddingBag, self).__init__()
137 | 
138 |         assert operation in [
139 |             "concat",
140 |             "mult",
141 |             "add",
142 |         ], "operation must be one of 'concat', 'mult', or 'add'"
143 | 
144 |         self.num_categories = num_categories
145 |         if isinstance(embedding_dim, int):
146 |             embedding_dim = [embedding_dim, embedding_dim]
147 |         elif len(embedding_dim) == 1:
148 |             # A single-element list is duplicated for both tables.
149 |             embedding_dim = [embedding_dim[0], embedding_dim[0]]
150 |         self.embedding_dim = embedding_dim
151 |         self.num_collisions = num_collisions
152 |         self.operation = operation
153 |         self.max_norm = max_norm
154 |         self.norm_type = norm_type
155 |         self.scale_grad_by_freq = scale_grad_by_freq
156 | 
157 |         if self.operation == "add" or self.operation == "mult":
158 |             assert (
159 |                 self.embedding_dim[0] == self.embedding_dim[1]
160 |             ), "Embedding dimensions do not match!"
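161 | 
162 |         # Table sizes for the quotient-remainder decomposition: each raw index i
163 |         # maps to the pair (i // num_collisions, i % num_collisions), which is
164 |         # unique for i in [0, num_categories), so these two tables together
165 |         # cover every category.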
166 |         self.num_embeddings = [
167 |             int(np.ceil(num_categories / num_collisions)),
168 |             num_collisions,
169 |         ]
170 | 
171 |         if _weight is None:
172 |             self.weight_q = Parameter(
173 |                 torch.Tensor(self.num_embeddings[0], self.embedding_dim[0])
174 |             )
175 |             self.weight_r = Parameter(
176 |                 torch.Tensor(self.num_embeddings[1], self.embedding_dim[1])
177 |             )
178 |             self.reset_parameters()
179 |         else:
180 |             assert list(_weight[0].shape) == [
181 |                 self.num_embeddings[0],
182 |                 self.embedding_dim[0],
183 |             ], "Shape of weight for quotient table does not match num_embeddings and embedding_dim"
184 |             assert list(_weight[1].shape) == [
185 |                 self.num_embeddings[1],
186 |                 self.embedding_dim[1],
187 |             ], "Shape of weight for remainder table does not match num_embeddings and embedding_dim"
188 |             self.weight_q = Parameter(_weight[0])
189 |             self.weight_r = Parameter(_weight[1])
190 |         self.mode = mode
191 |         self.sparse = sparse
192 | 
193 |     def reset_parameters(self):
194 |         # Uniform initialization bounded by sqrt(1 / num_categories), per the
195 |         # class docstring.
196 |         bound = np.sqrt(1 / self.num_categories)
197 |         nn.init.uniform_(self.weight_q, -bound, bound)
198 |         nn.init.uniform_(self.weight_r, -bound, bound)
199 | 
200 |     def forward(self, input, offsets=None, per_sample_weights=None):
201 |         # Integer quotient and remainder of the raw indices; floor division
202 |         # avoids routing long indices through a lossy float intermediate.
203 |         input_q = (input // self.num_collisions).long()
204 |         input_r = torch.remainder(input, self.num_collisions).long()
205 | 
206 |         embed_q = F.embedding_bag(
207 |             input_q,
208 |             self.weight_q,
209 |             offsets,
210 |             self.max_norm,
211 |             self.norm_type,
212 |             self.scale_grad_by_freq,
213 |             self.mode,
214 |             self.sparse,
215 |             per_sample_weights,
216 |         )
217 |         embed_r = F.embedding_bag(
218 |             input_r,
219 |             self.weight_r,
220 |             offsets,
221 |             self.max_norm,
222 |             self.norm_type,
223 |             self.scale_grad_by_freq,
224 |             self.mode,
225 |             self.sparse,
226 |             per_sample_weights,
227 |         )
228 | 
229 |         if self.operation == "concat":
230 |             embed = torch.cat((embed_q, embed_r), dim=1)
231 |         elif self.operation == "add":
232 |             embed = embed_q + embed_r
233 |         else:  # "mult", guaranteed by the assert in __init__
234 |             embed = embed_q * embed_r
235 | 
236 |         return embed
237 | 
238 |     def extra_repr(self):
239 |         s = "{num_embeddings}, {embedding_dim}"
240 |         if self.max_norm is not None:
241 |             s += ", max_norm={max_norm}"
242 |         if self.norm_type != 2:
243 |             s += ", norm_type={norm_type}"
244 |         if self.scale_grad_by_freq is not False:
245 |             s += ", scale_grad_by_freq={scale_grad_by_freq}"
246 |         s += ", mode={mode}"
247 |         return s.format(**self.__dict__)
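248 | 
249 | 
250 | if __name__ == "__main__":
251 |     # Minimal usage sketch (not part of the original module): hash 1000 raw
252 |     # categories into a 250-row quotient table and a 4-row remainder table,
253 |     # then reduce two bags ([3, 997] and [42]) with mode="sum".
254 |     torch.manual_seed(0)
255 |     qr = QREmbeddingBag(
256 |         num_categories=1000, embedding_dim=16, num_collisions=4, mode="sum"
257 |     )
258 |     input = torch.tensor([3, 997, 42])
259 |     offsets = torch.tensor([0, 2])
260 |     print(qr(input, offsets).shape)  # torch.Size([2, 16])
--------------------------------------------------------------------------------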