├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── general_question.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── python_lint.yml ├── .gitignore ├── .gitmodules ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs └── using_ET.md ├── et_replay ├── .gitignore ├── README.md ├── __init__.py ├── comm │ ├── backend │ │ ├── base_backend.py │ │ ├── pytorch_dist_backend.py │ │ └── pytorch_tpu_backend.py │ ├── commsTraceParser.py │ ├── comms_utils.py │ ├── param_profile.py │ └── profiler_trace_analysis.py ├── et_replay_utils.py ├── execution_trace.py ├── pyproject.toml ├── tests │ ├── inputs │ │ ├── 1.0.3-chakra.0.0.4 │ │ │ └── resnet_1gpu_et.json.gz │ │ ├── 1.1.0-chakra.0.0.4 │ │ │ └── resnet_2gpu_et.json.gz │ │ ├── __init__.py │ │ ├── dlrm_kineto.tar.gz │ │ ├── dlrm_pytorch_et.tar.gz │ │ ├── linear_et.json.gz │ │ ├── linear_kineto.json.gz │ │ ├── resnet_et.json.gz │ │ └── resnet_kineto.json.gz │ └── test_execution_trace.py ├── tools │ ├── comm_replay.py │ ├── et_replay.py │ └── validate_trace.py └── utils.py ├── inference └── compute │ └── pt │ └── pytorch_linear.py ├── requirements.txt ├── torchx_run.sh └── train ├── comms └── pt │ ├── README.md │ ├── comms.py │ ├── commsComputeBench.py │ ├── commsOverlapBench.py │ ├── commsTraceParser.py │ ├── commsTraceReplay.py │ ├── comms_utils.py │ ├── dlrm.py │ ├── dlrm_data.py │ ├── logger_utils.py │ ├── matmul_perf_model.py │ ├── param_profile.py │ ├── pytorch_backend_utils.py │ ├── pytorch_dist_backend.py │ ├── pytorch_tpu_backend.py │ ├── setup.py │ ├── tests │ ├── commsTraceReplay_tests.py │ ├── comms_utils_tests.py │ ├── mocks │ │ └── backend_mock.py │ └── test_utils.py │ └── triton_matmul.py ├── compute ├── pt │ ├── README.md │ ├── dataset.py │ ├── driver.py │ ├── pytorch_cutlass.py │ ├── pytorch_cvt_convs.py │ ├── pytorch_emb.py │ ├── pytorch_gemm.py │ └── pytorch_linear.py └── python │ ├── .gitignore │ ├── README.md │ ├── __init__.py │ ├── development.md │ ├── examples │ ├── __init__.py │ ├── cuda │ │ └── ncu_args.txt │ └── pytorch │ │ ├── __init__.py │ │ ├── configs │ │ ├── alex_net.json │ │ ├── aten_ops.json │ │ ├── batch_example.json │ │ ├── llama2.json │ │ ├── mm.json │ │ ├── mm_range.json │ │ ├── resnet.json │ │ ├── simple_add.json │ │ ├── simple_add_range.json │ │ ├── simple_mm.json │ │ ├── simple_mm_range.json │ │ └── split_table_batched_embeddings_ops.json │ │ ├── run_op.py │ │ └── run_op_split_table_batched_embeddings.py │ ├── lib │ ├── __init__.py │ ├── config.py │ ├── data.py │ ├── generator.py │ ├── init_helper.py │ ├── iterator.py │ ├── operator.py │ └── pytorch │ │ ├── __init__.py │ │ ├── benchmark.py │ │ ├── benchmark_helper.py │ │ ├── build_executor.py │ │ ├── config_util.py │ │ ├── cuda_util.py │ │ ├── data_impl.py │ │ ├── op_executor.py │ │ ├── operator_impl.py │ │ └── timer.py │ ├── pytorch │ ├── __init__.py │ ├── run_batch.py │ └── run_benchmark.py │ ├── requirements.txt │ ├── setup.py │ ├── test │ ├── __init__.py │ ├── pytorch │ │ ├── __init__.py │ │ └── configs │ │ │ └── test_native_basic_ops.json │ ├── test_benchmark_load.py │ ├── test_generator.py │ ├── test_register.py │ └── test_split_table_batched_embeddings_ops.py │ ├── tools │ ├── __init__.py │ └── nsys_analysis.py │ └── workloads │ ├── __init__.py │ └── pytorch │ ├── __init__.py │ ├── alex_net.py │ ├── native_basic_ops.py │ ├── resnet.py │ └── split_table_batched_embeddings_ops.py └── workloads └── README.md /.github/ISSUE_TEMPLATE/bug_report.md: 
-------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | --- 8 | 9 | ## Describe the Bug 10 | > A clear and concise description of what the bug is. 11 | 12 | ## Steps to Reproduce 13 | > Steps to reproduce the behavior. 14 | > Please include the version information where the bug was observed. 15 | 16 | ## Expected Behavior 17 | > A clear and concise description of what you expected to happen. 18 | 19 | ## Screenshots 20 | > If applicable, add screenshots to help explain your problem. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | --- 8 | 9 | ## Problem Related to the Feature 10 | > A clear and concise description of what the problem is. 11 | 12 | ## Proposed Solution 13 | > A clear and concise description of what you want to happen. 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/general_question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: General question 3 | about: Ask a question or seek clarification about the project 4 | title: '' 5 | labels: 'question' 6 | assignees: '' 7 | --- 8 | 9 | > Please provide a detailed description of your question or the information you seek. 10 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Summary 2 | Provide a concise summary of the changes introduced by this pull request. Detail the purpose and scope of the changes, referencing any relevant issues or discussions. Explain how these changes address the problem or improve the project. 3 | 4 | ## Test Plan 5 | In this section, describe the testing you have performed to verify the changes. Include: 6 | - A clear description of the testing environment. 7 | - The steps you followed to test the new features or bug fixes. 8 | - Any specific commands used during testing, along with their outputs. 9 | - A description of the results and observations from your testing. 10 | This information is crucial for reviewers to understand how the changes have been validated. 11 | 12 | ## Additional Notes 13 | Include any other notes or comments about the pull request here. This can include challenges faced, future considerations, or context that reviewers might find helpful. 
14 | -------------------------------------------------------------------------------- /.github/workflows/python_lint.yml: -------------------------------------------------------------------------------- 1 | name: Python Linting 2 | 3 | on: pull_request 4 | 5 | jobs: 6 | lint-and-format: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - name: Checkout Code 11 | uses: actions/checkout@v2 12 | 13 | - name: Setup Python Environment 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: '3.10' 17 | 18 | - name: Install Dependencies 19 | run: | 20 | pip install black 21 | # OpenSource Black seems not match Meta internal version 22 | # temporarily disable it until we figure out how to make 23 | # them consistent 24 | # - name: Run Black 25 | # run: black . --check 26 | 27 | - name: Run tests 28 | run: | 29 | python -m pip install -r requirements.txt 30 | python -m pip install et_replay/ 31 | python et_replay/tests/test_execution_trace.py 32 | 33 | - name: Validate imports 34 | run: | 35 | python -m pip install fbgemm-gpu 36 | python -c 'from et_replay import ExecutionTrace' 37 | python -c 'from et_replay.comm import comms_utils' 38 | python -c 'from et_replay.tools.validate_trace import TraceValidator' 39 | python -c 'from et_replay.utils import trace_handler' 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv/ 2 | .vscode/ 3 | __pycache__/ 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "train/workloads/dlrm"] 2 | path = train/workloads/dlrm 3 | url = https://github.com/facebookresearch/dlrm.git 4 | branch = dist_exp 5 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # This is a comment. 2 | # Each line is a file pattern followed by one or more owners. 3 | 4 | # These owners will be the default owners for everything in 5 | # the repo. Unless a later match takes precedence. 6 | * @srinivas212 @kingchc @louisfeng @sunghlin @wfu-fb @shengbao-zheng @briancoutinho 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to PARAM_Bench 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. 
Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to PARAM-Bench, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PARAM 2 | 3 | PARAM Benchmarks is a repository of communication and compute micro-benchmarks as well as full workloads for evaluating training and inference platforms. 4 | 5 | PARAM complements two broad categories of commonly used benchmarks: 6 | 1. C++ based stand-alone compute and communication benchmarks using cuDNN, MKL, NCCL, MPI libraries - e.g., NCCL tests (https://github.com/NVIDIA/nccl-tests), OSU MPI benchmarks (https://mvapich.cse.ohio-state.edu/benchmarks/), and DeepBench (https://github.com/baidu-research/DeepBench). 7 | 2. Application benchmarks such as Deep Learning Recommendation Model (DLRM) and the broader MLPerf benchmarks. 
It's worth noting that while MLPerf is the de facto industry standard for benchmarking ML applications, we hope to complement this effort with broader workloads that are of more interest to Facebook, along with more in-depth analysis of each within this branch of application benchmarks. 8 | 9 | Our initial release of PARAM benchmarks focuses on AI training and comprises: 10 | 1. Communication: PyTorch based collective benchmarks across arbitrary message sizes, effectiveness of compute-communication overlap, and DLRM communication patterns in fwd/bwd pass 11 | 2. Compute: PyTorch based GEMM, embedding lookup, and linear layer 12 | 3. DLRM: uses Facebook's DLRM benchmark (https://github.com/facebookresearch/dlrm), tracking its `ext_dist` branch. In short, PARAM fully relies on the DLRM benchmark for end-to-end workload evaluation, with additional extensions as required for scale-out AI training platforms. 13 | 4. PyTorch Execution Trace (ET) replay based tests: The recently introduced PyTorch ET capturing capabilities allow the runtime information of a model to be recorded at the operator level. This capability enables the creation of replay-based benchmarks (https://dl.acm.org/doi/abs/10.1145/3579371.3589072) that accurately reproduce the original performance. 14 | 15 | 16 | In essence, PARAM bridges the gap between stand-alone C++ benchmarks and PyTorch/Tensorflow based application benchmarks. This enables us to gain deep insights into the inner workings of the system architecture as well as identify framework-level overheads by stressing all subcomponents of a system. 17 | 18 | ## Version 19 | 20 | 0.1 : Initial release 21 | 22 | ## Requirements 23 | 24 | - pytorch 25 | - future 26 | - numpy 27 | - apex 28 | 29 | ## License 30 | 31 | PARAM benchmarks is released under the MIT license. Please see the [`LICENSE`](LICENSE) file for more information. 32 | 33 | ## Contributing 34 | 35 | We actively welcome your pull requests! Please see [`CONTRIBUTING.md`](CONTRIBUTING.md) and [`CODE_OF_CONDUCT.md`](CODE_OF_CONDUCT.md) for more info. 36 | -------------------------------------------------------------------------------- /docs/using_ET.md: -------------------------------------------------------------------------------- 1 | # Using Execution Trace in PARAM Benchmark 2 | 3 | This section describes how to collect a Chakra Execution Trace from a PyTorch training workload, as well as how to run PARAM replay on top of the collected ET. 4 | 5 | 6 | ## Execution Trace Collection 7 | Execution Trace collection logic has to be added to the main training loop. This involves three steps: 8 | 9 | ### Step 1: Set up Execution Trace Observer 10 | The first step is to create an Execution Trace Observer object and register a temporary file for ET storage. 11 | 12 | ``` 13 | import tempfile 14 | from torch.profiler import ExecutionTraceObserver 15 | et_ob = ExecutionTraceObserver() 16 | fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) 17 | fp.close() 18 | et_ob.register_callback(fp.name) 19 | ``` 20 | 21 | ### Step 2: Define your function to dump Execution Trace 22 | You have to define a function to store/dump/upload your collected ET trace for further use.
Here is an example: 23 | 24 | ``` 25 | def dump_execution_trace(tmp_et_path): 26 | DUMP_DIR.mkdir(exist_ok=True, parents=True) 27 | et_path = DUMP_DIR / f"rank-{global_rank}.et.json.gz" 28 | with open(tmp_et_path) as fin: 29 | with gzip.open(et_path, "wt") as fout: 30 | fout.writelines(fin) 31 | os.remove(tmp_et_path) 32 | print(f"Finished Rank {global_rank} ET collection at {et_path}") 33 | ``` 34 | 35 | ### Step 3: Collect Execution Trace in the training loop 36 | This is the key step to collect the ET. You have to insert the collection logic into the main training loop of your workload. 37 | Two parameters have to be set: 38 | - ET_START_ITER: the iteration at which to start ET collection 39 | - ET_END_ITER: the iteration at which to stop ET collection 40 | 41 | ``` 42 | 43 | while step < TRAINING_STEPS: 44 | ... 45 | ... 46 | # Collect Execution Trace Logic 47 | 48 | # Start ET collection 49 | if et_ob and step == ET_START_ITER: 50 | et_ob.start() 51 | 52 | # First record process group (PG) mapping 53 | pg_config_info = ( 54 | torch.distributed.distributed_c10d._world.pg_config_info 55 | ) 56 | rf_handle = torch.autograd._record_function_with_args_enter( 57 | "## process_group:init ##", json.dumps(pg_config_info) 58 | ) 59 | torch.autograd._record_function_with_args_exit(rf_handle) 60 | 61 | # Stop ET collection 62 | elif et_ob and step == ET_END_ITER: 63 | et_ob.stop() 64 | tmp_et_path = et_ob.get_output_file_path() 65 | et_ob.unregister_callback() 66 | dump_execution_trace(tmp_et_path) 67 | 68 | ... 69 | ... 70 | step += 1 71 | 72 | ``` 73 | 74 | Note that process group information collection is not automatically covered by the ET observer, because process group initialization happens before the main training loop. Therefore, you have to manually add PG information collection, as shown in the code above. 75 | 76 | 77 | 78 | 79 | ## PARAM Comms Replay on Execution Trace 80 | Execution Trace is now fully supported in the PARAM benchmark. To replay an ET trace, you only need to specify `--trace-type=et`, and the benchmark will parse your ET and replay the collective communication operators. 81 | 82 | An example command: 83 | 84 | ``` 85 | /bin/mpirun -np 8 commsTraceReplay.par --trace-path --trace-type et 86 | ``` 87 | -------------------------------------------------------------------------------- /et_replay/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | et_replay.egg-info/ 3 | __pycache__/ 4 | -------------------------------------------------------------------------------- /et_replay/README.md: -------------------------------------------------------------------------------- 1 | # Execution Trace Replay (et_replay) 2 | `et_replay` is a tool designed for replaying Chakra Execution Traces (ET) from machine learning models. 3 | 4 | ## Installation 5 | To install `et_replay`, use the following commands: 6 | 7 | ```bash 8 | $ git clone --recurse-submodules git@github.com:facebookresearch/param.git 9 | $ conda create -n et_replay python=3.10 10 | $ conda activate et_replay 11 | $ cd param 12 | $ pip3 install -r requirements.txt 13 | $ cd et_replay 14 | $ pip3 install . 15 | ``` 16 | 17 | ## Running et_replay 18 | To use et_replay, execution traces are required. 19 | Start by collecting an execution trace using the command below. This command runs a benchmark with specific configurations and enables execution tracing.
20 | ```bash 21 | $ python -m param_bench.train.compute.python.pytorch.run_benchmark -c train/compute/python/examples/pytorch/configs/simple_add.json --et 22 | ``` 23 | 24 | After collecting the trace, replay it with the following command. Set the warm-up iteration count to at least 1 to exclude tensor transfer time to GPUs. 25 | ```bash 26 | $ python -m et_replay.tools.et_replay --input --warmup-iter 10 --iter 50 --compute --profile-replay 27 | ``` 28 | 29 | > Note: When analyzing performance values from et_replay, refer to the collected Kineto traces rather than the execution time reported by et_replay. Kineto traces are only collected when --profile-replay is provided. 30 | -------------------------------------------------------------------------------- /et_replay/__init__.py: -------------------------------------------------------------------------------- 1 | from et_replay.execution_trace import ExecutionTrace 2 | 3 | __all__ = ["ExecutionTrace"] 4 | -------------------------------------------------------------------------------- /et_replay/comm/backend/pytorch_tpu_backend.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch_xla.core.xla_model as xm # pyre-ignore[21]: 8 | import torch_xla.distributed.xla_multiprocessing as xmp # pyre-ignore[21]: 9 | 10 | from et_replay.comm.backend.base_backend import BaseBackend 11 | 12 | 13 | class PyTorchTPUBackend(BaseBackend): 14 | def sayHello(self): 15 | myhost = os.uname()[1] 16 | device = self.get_device() 17 | hw_device = self.get_hw_device() 18 | global_rank = self.get_global_rank() 19 | local_rank = self.get_local_rank() 20 | world_size = self.get_world_size() 21 | master_ip = self.bootstrap_info.master_ip 22 | print( 23 | "\tRunning on host: %s g-rank: %d, l-rank: %s world_size: %d master_ip: %s device: %s (%s)" 24 | % ( 25 | myhost, 26 | global_rank, 27 | local_rank, 28 | world_size, 29 | master_ip, 30 | device, 31 | hw_device, 32 | ) 33 | ) 34 | 35 | # Collectives 36 | def all_reduce(self, collectiveArgs, retFlag=False): 37 | retObj = xm.all_reduce(collectiveArgs.op, [collectiveArgs.ipTensor]) 38 | if collectiveArgs.asyncOp: 39 | collectiveArgs.waitObj.append(retObj) 40 | if retFlag: 41 | return retObj 42 | 43 | def reduce(self, collectiveArgs, retFlag=False): 44 | raise NotImplementedError("Func reduce: not implemented yet on TPU") 45 | 46 | def all_to_all(self, collectiveArgs, retFlag=False): 47 | retObj = xm.all_to_all(collectiveArgs.ipTensor, 0, 0, collectiveArgs.world_size) 48 | collectiveArgs.opTensor = retObj 49 | if collectiveArgs.asyncOp: 50 | collectiveArgs.waitObj.append(retObj) 51 | if retFlag: 52 | return retObj 53 | 54 | def all_to_allv(self, collectiveArgs, retFlag=False): 55 | raise NotImplementedError("Func all_to_allv: not implemented yet on TPU") 56 | 57 | def all_gather(self, collectiveArgs, retFlag=False): 58 | retObj = xm.all_gather(collectiveArgs.ipTensor, dim=0) 59 | collectiveArgs.opTensor = retObj 60 | if collectiveArgs.asyncOp: 61 | collectiveArgs.waitObj.append(retObj) 62 | if retFlag: 63 | return retObj 64 | 65 | def complete_accel_ops(self, collectiveArgs): 66 | xm.mark_step() 67 | 68 | def get_reduce_op(self, opName): 69 | if opName == "sum": 70 | return xm.REDUCE_SUM 71 | elif opName == "max": 72 | return xm.REDUCE_MAX 73 | else: 74 | return xm.REDUCE_SUM 75 | 76 | def barrier(self, collectiveArgs, name="world"): 77 | xm.rendezvous(name) 78 | 79 | 
# Compute functions 80 | def compute_mm(self, collectiveArgs): 81 | self.gemm(collectiveArgs) 82 | 83 | def gemm(self, collectiveArgs): 84 | collectiveArgs.MMout = torch.mm(collectiveArgs.MMin1, collectiveArgs.MMin2) 85 | 86 | # Memory related 87 | def get_mem_size(self, collectiveArgs): 88 | return ( 89 | collectiveArgs.ipTensor.nelement() * collectiveArgs.ipTensor.element_size() 90 | ) 91 | 92 | def alloc_random(self, sizeArr, curRankDevice, dtype, scaleFactor=1.0): 93 | if dtype in (torch.int32, torch.long): 94 | ipTensor = torch.randint( 95 | 0, 1000, sizeArr, device=curRankDevice, dtype=dtype 96 | ) 97 | else: 98 | ipTensor = torch.rand(sizeArr, device=curRankDevice, dtype=dtype) 99 | # ipTensor = torch.full( 100 | # sizeArr, self.get_global_rank(), device=curRankDevice, dtype=dtype 101 | # ) 102 | # print("IP: ", ipTensor, self.get_hw_device()) 103 | if (scaleFactor) != 0: 104 | ipTensor = ipTensor / scaleFactor 105 | return ipTensor 106 | 107 | def alloc_embedding_tables(self, n, m, curRankDevice, dtype): 108 | EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True) 109 | 110 | W = np.random.uniform( 111 | low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m) 112 | ).astype(np.float32) 113 | # approach 1 114 | 115 | EE.weight.data = torch.tensor( 116 | W, dtype=dtype, requires_grad=True, device=curRankDevice 117 | ) 118 | return EE 119 | 120 | def alloc_empty(self, sizeArr, dtype, curRankDevice): 121 | return torch.empty(sizeArr, device=curRankDevice, dtype=dtype) 122 | 123 | def clear_memory(self, collectiveArgs): 124 | pass # torch.cuda.empty_cache() 125 | 126 | # Getting world-size and other information. 127 | def get_local_rank( 128 | self, 129 | ): 130 | return xm.get_local_ordinal() 131 | 132 | def get_local_size( 133 | self, 134 | ): 135 | return self.bootstrap_info.local_size 136 | 137 | def get_global_rank( 138 | self, 139 | ): 140 | return xm.get_ordinal() 141 | 142 | def get_world_size( 143 | self, 144 | ): 145 | return xm.xrt_world_size() 146 | 147 | def get_device( 148 | self, 149 | ): 150 | return xm.xla_device() 151 | 152 | def get_hw_device( 153 | self, 154 | ): 155 | return xm._xla_real_device(xm.xla_device()) 156 | 157 | def get_default_group(self): 158 | pass 159 | 160 | def get_groups(self): 161 | pass 162 | 163 | def tensor_list_to_numpy(self, tensorList): 164 | tensorList = torch.transpose(tensorList.view(-1, 1), 0, 1)[0] 165 | return tensorList.cpu().detach().numpy() 166 | 167 | # Init functions 168 | def __init__(self, bootstrap_info, commsParams): 169 | self.bootstrap_info = bootstrap_info 170 | self.commsParams = commsParams 171 | 172 | def initialize_backend(self, master_ip, master_port, backend="gloo"): 173 | pass 174 | 175 | def benchmark_comms(self, benchTime, commsParams): 176 | xmp.spawn( 177 | fn=benchTime, 178 | args=(commsParams, self), 179 | nprocs=self.bootstrap_info.num_tpu_cores, 180 | ) 181 | return 182 | 183 | def __del__(self): 184 | pass 185 | -------------------------------------------------------------------------------- /et_replay/comm/commsTraceParser.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
2 | from __future__ import annotations 3 | 4 | import json 5 | 6 | import logging 7 | 8 | import math 9 | 10 | from et_replay.comm import comms_utils 11 | from et_replay.comm.backend.base_backend import supportedP2pOps 12 | from et_replay.comm.comms_utils import commsArgs 13 | 14 | from et_replay.execution_trace import ExecutionTrace 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def parseTrace( 20 | in_trace: list, 21 | trace_type: str, 22 | trace_file_path: str, 23 | target_rank: int, 24 | total_ranks: int, 25 | ) -> list: 26 | """ 27 | Parse trace files to be compatible with PARAM replay-mode. 28 | Currently supports: Chakra host execution trace. 29 | 30 | Args: 31 | in_trace: Trace file to be parsed. 32 | trace_type: Trace type to be parsed with 33 | trace_file_path: Path of input trace file being loaded. 34 | target_rank: The current rank of the device. 35 | total_ranks: Total number of ranks. 36 | Returns: 37 | parsed_trace: Parsed trace that is compatible with PARAM replay-mode. 38 | """ 39 | 40 | if trace_type == "et": # Execution Trace (e.g. Chakra host execution trace) 41 | parsed_trace = _parseExecutionTrace( 42 | ExecutionTrace(in_trace), target_rank, total_ranks 43 | ) 44 | else: 45 | raise ValueError( 46 | f"Specified trace type {trace_type} to {trace_file_path} is not supported. \ 47 | Please check supported types with '--help'" 48 | ) 49 | 50 | return parsed_trace 51 | 52 | 53 | def _parseExecutionTrace( 54 | in_trace: ExecutionTrace, target_rank: int, total_ranks: int 55 | ) -> list: 56 | """ 57 | Convert the Execution Trace comms metadata to the common trace format for replay. 58 | """ 59 | if in_trace.schema_pytorch() < (1, 0, 3): 60 | raise ValueError( 61 | f"Only support trace version >1.0.3, but current trace version is {in_trace.schema.split('-')[0]}" 62 | ) 63 | 64 | # pg_ranks_map: key is pg id, value is global ranks in this pg 65 | # pg_desc_map: key is pg id, value is pg desc 66 | pg_ranks_map, pg_desc_map = _parse_proc_group_info(in_trace) 67 | comms_op_list = _parse_comms_op_node( 68 | in_trace, pg_ranks_map, pg_desc_map, target_rank, total_ranks 69 | ) 70 | 71 | return comms_op_list 72 | 73 | 74 | def _parse_proc_group_info(in_trace: ExecutionTrace): 75 | pg_ranks_map = {} # {node_id : {process_group_id : [ranks] } } 76 | pg_desc_map = {} # {node_id : {process_group_id : pg_desc } 77 | pg_init_nodes = ( 78 | node for node in in_trace.nodes.values() if "process_group:init" in node.name 79 | ) 80 | for node in pg_init_nodes: 81 | # info of this node is dumped using torch.distributed.distributed_c10d._world.pg_config_info 82 | # at the start of profiling, but not callback to torch.distributed.init_process_group() 83 | # Pre-Assumption: all process groups has been created before profiling start. 84 | try: 85 | pg_objs = json.loads(node.inputs[0]) 86 | except json.decoder.JSONDecodeError: # skip if pg_config_info is truncated 87 | break 88 | 89 | pg_ranks_map[node.id] = {} 90 | pg_desc_map[node.id] = {} 91 | for pg in pg_objs: 92 | if not pg["pg_name"].isdecimal(): 93 | # TODO support local synchronization pg 94 | logger.warning( 95 | f"Process group name is {pg['pg_name']} in node {node.id}, which is not supported. Skip." 
96 | ) 97 | continue 98 | (pg_id, pg_desc, ranks, group_size, group_count) = ( 99 | pg[k] 100 | for k in ["pg_name", "pg_desc", "ranks", "group_size", "group_count"] 101 | ) 102 | pg_id = int(pg_id) 103 | pg_ranks_map[node.id][pg_id] = ( 104 | ranks if len(ranks) > 0 else list(range(group_size)) 105 | # rank list is empty when all ranks are in a pg 106 | ) 107 | pg_desc_map[node.id][pg_id] = pg_desc 108 | break # only one process_group init node per trace 109 | return pg_ranks_map, pg_desc_map 110 | 111 | 112 | def _parse_comms_op_node( # noqa: C901 113 | in_trace: ExecutionTrace, 114 | pg_ranks_map: dict, 115 | pg_desc_map: dict, 116 | target_rank: int, 117 | total_ranks: int, 118 | ): 119 | comms_op_list = [] 120 | 121 | for node_id in pg_ranks_map: 122 | for pg_id, ranks in pg_ranks_map[node_id].items(): 123 | comm_args = _create_pg_init_node( 124 | node_id, pg_id, ranks, pg_desc_map[node_id][pg_id], len(ranks) 125 | ) 126 | comms_op_list.append(comm_args) 127 | 128 | pg_ranks_map_flatten = {} 129 | for _, v in pg_ranks_map.items(): 130 | pg_ranks_map_flatten.update(v) 131 | 132 | comm_nodes = ( 133 | node for node in in_trace.nodes.values() if node.name == "record_param_comms" 134 | ) 135 | is_seq_id = ( 136 | lambda x: isinstance(x, list) 137 | and len(x) == 2 138 | and isinstance(x[0], int) 139 | and isinstance(x[1], bool) 140 | ) 141 | for node in comm_nodes: 142 | # for ["wait", "barrier", "init"] ops, before having different seq_id for p2p op and non p2p op, seq_id is an integer for the first input 143 | # After having different seq_id for p2p op and non p2p op, seq_id is a list of [seq_id, isP2P] for the first input 144 | # Need to handle both cases, in the future this kind of change should have different version of schema, and we can use version to decide how to parse the trace 145 | if is_seq_id(node.inputs[0]) or isinstance(node.inputs[0], int): 146 | index_base = 0 147 | else: 148 | index_base = 1 149 | req_id = node.inputs[index_base] 150 | recorded_rank = node.inputs[index_base + 2] 151 | 152 | comm_args = commsArgs() 153 | comm_args.id = node.id 154 | comm_args.comms = comms_utils.paramToCommName( 155 | node.commArgs.collective_name.lower() 156 | ) 157 | if comm_args.comms == "init": 158 | # init node has been built 159 | continue 160 | 161 | if isinstance(req_id, int): 162 | # this is the format before having different seq_id for p2p op and non p2p op 163 | comm_args.req = (req_id, False) 164 | else: 165 | comm_args.req = req_id 166 | 167 | if node.commArgs.pg_name and node.commArgs.pg_name.isdecimal(): 168 | comm_args.pgId = int(node.commArgs.pg_name) 169 | comm_args.groupRanks = pg_ranks_map_flatten[comm_args.pgId] 170 | comm_args.worldSize = len(comm_args.groupRanks) 171 | 172 | if comm_args.comms not in ("wait", "barrier"): 173 | comm_args.inMsgSize = node.commArgs.in_msg_nelems 174 | comm_args.outMsgSize = node.commArgs.out_msg_nelems 175 | comm_args.dtype = node.commArgs.dtype.lower() 176 | 177 | # the recorded rank id in execution trace is local rank id in the process group 178 | # we need to convert it to global rank for replay, check the function broadcast() of pytorch below: 179 | # https://github.com/pytorch/pytorch/blob/6c4efd4e959017fc758fcc5dc32d8cc6a4b9164d/torch/distributed/distributed_c10d.py#L2404 180 | if comm_args.comms in supportedP2pOps: 181 | if "send" in comm_args.comms: 182 | (comm_args.src_rank, comm_args.dst_rank) = ( 183 | target_rank, 184 | comm_args.groupRanks[recorded_rank], 185 | ) 186 | elif "recv" in comm_args.comms: 187 | 
(comm_args.src_rank, comm_args.dst_rank) = ( 188 | comm_args.groupRanks[recorded_rank], 189 | target_rank, 190 | ) 191 | elif comm_args.comms in ["reduce", "broadcast", "gather", "scatter"]: 192 | comm_args.root = comm_args.groupRanks[recorded_rank] 193 | comm_args.groupRanks = comm_args.groupRanks 194 | 195 | if comm_args.comms == "all_to_all": 196 | # flatten each tensor and store the # of elements into split field 197 | comm_args.inSplit = [math.prod(i) for i in node.input_shapes[0]] 198 | comm_args.outSplit = [math.prod(i) for i in node.output_shapes[0]] 199 | elif comm_args.comms == "all_to_allv": 200 | if not comm_args.worldSize: 201 | # if no pg info provided, use total ranks as world size 202 | comm_args.worldSize = total_ranks 203 | comm_args.inSplit = json.loads(node.commArgs.in_split_size) 204 | comm_args.outSplit = json.loads(node.commArgs.out_split_size) 205 | 206 | comms_op_list.append(comm_args) 207 | 208 | return comms_op_list 209 | 210 | 211 | def _create_pg_init_node( 212 | node_id: int, pg_id: int, ranks: list[int], pg_desc: str, world_size: int 213 | ): 214 | comm_args = commsArgs() 215 | comm_args.id = node_id 216 | comm_args.comms = "init" 217 | comm_args.pgId = pg_id 218 | comm_args.pgDesc = pg_desc 219 | comm_args.req = -1 220 | comm_args.groupRanks = ranks 221 | comm_args.worldSize = world_size 222 | return comm_args 223 | -------------------------------------------------------------------------------- /et_replay/comm/param_profile.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from __future__ import annotations 7 | 8 | import logging 9 | import time 10 | from dataclasses import dataclass 11 | from typing import Any 12 | 13 | from torch.autograd.profiler import record_function 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class paramProfile(record_function): 19 | """Inherit from PyTorch profiler to enable autoguard profiling while measuring the time interval in PARAM""" 20 | 21 | def __init__(self, timer: paramTimer | None = None, description: str = "") -> None: 22 | super().__init__(name=description) 23 | self.description = description 24 | self.timer = timer 25 | self.start = 0.0 26 | self.end = 0.0 27 | self.intervalNS = 0.0 28 | 29 | def __enter__(self) -> paramProfile: 30 | super().__enter__() 31 | self.start = time.monotonic() 32 | return self 33 | 34 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 35 | self.end = time.monotonic() 36 | self.intervalNS = (self.end - self.start) * 1e9 # keeping time in NS 37 | # if given a valid paramTimer object, directly update the measured time interval 38 | if isinstance(self.timer, paramTimer): 39 | self.timer.incrTimeNS(self.intervalNS) 40 | logger.debug(f"{self.description} took {self.intervalNS} ns") 41 | super().__exit__(exc_type, exc_value, traceback) 42 | 43 | 44 | @dataclass 45 | class paramTimer: 46 | """ 47 | Timer for param profiler. 
48 | """ 49 | 50 | elapsedTimeNS: float = 0.0 # keeping time in NS 51 | 52 | def reset(self, newTime: float = 0.0) -> None: 53 | self.elapsedTimeNS = newTime 54 | 55 | def incrTimeNS(self, timeNS: float) -> None: 56 | self.elapsedTimeNS += timeNS 57 | 58 | def getTimeUS(self) -> float: 59 | return self.elapsedTimeNS / 1e3 60 | 61 | def getTimeNS(self) -> float: 62 | return self.elapsedTimeNS 63 | -------------------------------------------------------------------------------- /et_replay/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "et_replay" 7 | version = "0.5.0" 8 | dependencies = [ 9 | "numpy", 10 | "intervaltree", 11 | "pydot", 12 | ] 13 | 14 | [tool.setuptools.package-dir] 15 | "et_replay" = "." 16 | "param_bench" = ".." 17 | 18 | [project.scripts] 19 | comm_replay = "et_replay.tools.comm_replay:main" 20 | et_replay = "et_replay.tools.et_replay:main" 21 | validate_traces = "et_replay.tools.validate_traces:main" 22 | -------------------------------------------------------------------------------- /et_replay/tests/inputs/1.0.3-chakra.0.0.4/resnet_1gpu_et.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/et_replay/tests/inputs/1.0.3-chakra.0.0.4/resnet_1gpu_et.json.gz -------------------------------------------------------------------------------- /et_replay/tests/inputs/1.1.0-chakra.0.0.4/resnet_2gpu_et.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/et_replay/tests/inputs/1.1.0-chakra.0.0.4/resnet_2gpu_et.json.gz -------------------------------------------------------------------------------- /et_replay/tests/inputs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/et_replay/tests/inputs/__init__.py -------------------------------------------------------------------------------- /et_replay/tests/inputs/dlrm_kineto.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/et_replay/tests/inputs/dlrm_kineto.tar.gz -------------------------------------------------------------------------------- /et_replay/tests/inputs/dlrm_pytorch_et.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/et_replay/tests/inputs/dlrm_pytorch_et.tar.gz -------------------------------------------------------------------------------- /et_replay/tests/inputs/linear_et.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/et_replay/tests/inputs/linear_et.json.gz -------------------------------------------------------------------------------- /et_replay/tests/inputs/linear_kineto.json.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/et_replay/tests/inputs/linear_kineto.json.gz -------------------------------------------------------------------------------- /et_replay/tests/inputs/resnet_et.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/et_replay/tests/inputs/resnet_et.json.gz -------------------------------------------------------------------------------- /et_replay/tests/inputs/resnet_kineto.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/et_replay/tests/inputs/resnet_kineto.json.gz -------------------------------------------------------------------------------- /et_replay/tests/test_execution_trace.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import os 4 | import unittest 5 | 6 | from et_replay import ExecutionTrace 7 | from et_replay.tools.validate_trace import TraceValidator 8 | 9 | CURR_DIR = os.path.dirname(os.path.realpath(__file__)) 10 | 11 | 12 | class TestTraceLoadAndValidate(unittest.TestCase): 13 | def setUp(self): 14 | self.trace_base = os.path.join(CURR_DIR, "inputs") 15 | 16 | def _test_and_validate_trace(self, trace_file): 17 | with ( 18 | gzip.open(trace_file, "rb") 19 | if trace_file.endswith("gz") 20 | else open(trace_file) 21 | ) as execution_data: 22 | execution_trace: ExecutionTrace = ExecutionTrace(json.load(execution_data)) 23 | t = TraceValidator(execution_trace) 24 | self.assertTrue(t.validate()) 25 | return t, execution_trace 26 | 27 | def test_trace_load_resnet_1gpu_ptorch_1_0_3(self): 28 | et_file = os.path.join( 29 | self.trace_base, "1.0.3-chakra.0.0.4/resnet_1gpu_et.json.gz" 30 | ) 31 | t, et = self._test_and_validate_trace(et_file) 32 | self.assertGreater(t.num_ops(), 1000) 33 | self.assertEqual(t.num_comm_ops(), 27) 34 | self.assertEqual(t.num_triton_ops(), 0) 35 | 36 | def test_trace_load_resnet_2gpu_ptorch_1_1_0(self): 37 | et_file = os.path.join( 38 | self.trace_base, "1.1.0-chakra.0.0.4/resnet_2gpu_et.json.gz" 39 | ) 40 | t, et = self._test_and_validate_trace(et_file) 41 | self.assertGreater(t.num_ops(), 1000) 42 | self.assertEqual(t.num_comm_ops(), 27) 43 | self.assertEqual(t.num_triton_ops(), 0) 44 | 45 | 46 | if __name__ == "__main__": 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /et_replay/tools/validate_trace.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import gzip 4 | import json 5 | 6 | from et_replay.execution_trace import ExecutionTrace 7 | 8 | 9 | class TraceValidator: 10 | def __init__(self, execution_trace: ExecutionTrace): 11 | self.et = execution_trace 12 | 13 | def _ops(self): 14 | return (n for n in self.et.nodes.values() if n.is_op()) 15 | 16 | def _validate_ops(self) -> bool: 17 | """Make sure the pytorch operators are valid""" 18 | ops = self._ops() 19 | for op in ops: 20 | if op.name == "": 21 | print(f"op should have valid name, node id = {op.id}") 22 | 23 | # if len(list(op.get_outputs())) + len(list(op.get_inputs())) == 0: 24 | # print(f"op should have outputs or inputs, node = {op.name}") 25 | # FIXME see "autograd::engine::evaluate_function: DivBackward1" 26 | # 
currently let's skip this 27 | # return False 28 | return True 29 | 30 | def _validate_tree(self) -> bool: 31 | """TBD validate that the generated datastructure is a tree 32 | with parent/child relationship. We can use pydot or networkx libs for this 33 | """ 34 | return True 35 | 36 | def _validate_param_comms(self) -> bool: 37 | """Check if param comms has correct attributes""" 38 | 39 | if self.et.schema_pytorch() < (1, 0, 2): 40 | return True 41 | 42 | def check_comms_node_pre_1_1_0(n) -> bool: 43 | """Roughly based on commsTraceParser""" 44 | # https://github.com/facebookresearch/param/blob/main/train/comms/pt/commsTraceParser.py#L256 45 | 46 | has_pg_id = False 47 | # Slightly hacky but find a argument with tuple type 48 | for arg in n.get_inputs(): 49 | if arg[0] == "Tuple[String,String]": 50 | print(f" {n.name}, process group args = {arg}") 51 | has_pg_id = True 52 | return has_pg_id 53 | 54 | def check_comms_node_1_1_0(n) -> bool: 55 | """New elements are added as per 56 | https://github.com/pytorch/pytorch/issues/124674 57 | """ 58 | # TODO check for node.commArgs dataclass 59 | print(n.commArgs) 60 | return True 61 | 62 | check_comms_node = ( 63 | check_comms_node_1_1_0 64 | if self.et.schema_pytorch() >= (1, 1, 0) 65 | else check_comms_node_pre_1_1_0 66 | ) 67 | 68 | return all( 69 | check_comms_node(n) 70 | for n in self.et.nodes.values() 71 | if n.is_op() and n.name == "record_param_comms" 72 | ) 73 | 74 | def _validate_triton(self) -> bool: 75 | """Make sure triton kernels have correct values 76 | TODO update for checking if kernel files are captured. 77 | """ 78 | return True 79 | 80 | def validate(self) -> bool: 81 | return all( 82 | [ 83 | self._validate_ops(), 84 | self._validate_tree(), 85 | self._validate_param_comms(), 86 | self._validate_triton(), 87 | ] 88 | ) 89 | 90 | def num_ops(self) -> int: 91 | return len(list(self._ops())) 92 | 93 | def num_comm_ops(self) -> int: 94 | return sum(1 for op in self._ops() if op.name == "record_param_comms") 95 | 96 | def num_triton_ops(self) -> int: 97 | return sum(1 for op in self._ops() if "triton" in op.name) 98 | 99 | 100 | def main(): 101 | import sys 102 | 103 | execution_json = sys.argv[1] 104 | 105 | with ( 106 | gzip.open(execution_json, "rb") 107 | if execution_json.endswith("gz") 108 | else open(execution_json) 109 | ) as execution_data: 110 | execution_trace: ExecutionTrace = ExecutionTrace(json.load(execution_data)) 111 | t = TraceValidator(execution_trace) 112 | print( 113 | f"num ops = {t.num_ops()}, num comms = {t.num_comm_ops()}, " 114 | f"num triton ops = {t.num_triton_ops()}" 115 | ) 116 | print("Trace validation result = ", t.validate()) 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /et_replay/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import gzip 3 | import json 4 | import logging 5 | import os 6 | import uuid 7 | from typing import Any 8 | 9 | from et_replay.execution_trace import ExecutionTrace 10 | 11 | 12 | def get_tmp_trace_filename() -> str: 13 | """Generate a temporary filename using the current date, a UUID, and the process ID.""" 14 | trace_fn = ( 15 | "tmp_" 16 | + datetime.datetime.today().strftime("%Y%m%d") 17 | + "_" 18 | + uuid.uuid4().hex[:7] 19 | + "_" 20 | + str(os.getpid()) 21 | + ".json" 22 | ) 23 | return trace_fn 24 | 25 | 26 | def trace_handler(prof: Any) -> None: 27 | """Export a chrome trace""" 28 | fn = 
get_tmp_trace_filename() 29 | prof.export_chrome_trace("/tmp/" + fn) 30 | logging.warning(f"Chrome profile trace written to /tmp/{fn}") 31 | 32 | 33 | def load_execution_trace_file(et_file_path: str) -> ExecutionTrace: 34 | """Loads Execution Trace from json file and parses it.""" 35 | data = read_dictionary_from_json_file(et_file_path) 36 | return ExecutionTrace(data) 37 | 38 | 39 | def read_dictionary_from_json_file(file_path: str) -> dict[Any, Any]: 40 | """Read a json file and return it as a dictionary.""" 41 | with ( 42 | gzip.open(file_path, "rb") if file_path.endswith("gz") else open(file_path) 43 | ) as f: 44 | return json.load(f) 45 | 46 | 47 | def write_dictionary_to_json_file(file_path: str, data: dict[Any, Any]) -> None: 48 | """Write input dictionary to a json file.""" 49 | if file_path.endswith("gz"): 50 | with gzip.open(file_path, "w") as f: 51 | f.write(json.dumps(data, indent=4).encode("utf-8")) 52 | else: 53 | with open(file_path, "w") as f: 54 | json.dump(data, f, indent=4) 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | future 3 | numpy 4 | pydot 5 | -------------------------------------------------------------------------------- /torchx_run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | # Entry point for torchx scripts 3 | 4 | set -eE 5 | 6 | # shellcheck disable=SC1091 7 | 8 | export LD_PRELOAD="${PRELOAD_PATH:=/usr/local/fbcode/platform010/lib/libcuda.so:/usr/local/fbcode/platform010/lib/libnvidia-ml.so}" 9 | export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${CONDA_DIR}/lib" 10 | export PYTHONPATH="${PYTHONPATH}:${TORCHX_RUN_PYTHONPATH}" 11 | 12 | # shellcheck disable=SC1091 13 | source "${CONDA_DIR}/bin/activate" 14 | cd "${WORKSPACE_DIR}" 15 | python3 -X faulthandler "$@" 16 | -------------------------------------------------------------------------------- /train/comms/pt/README.md: -------------------------------------------------------------------------------- 1 | # PARAM benchmark - Communication benchmarks 2 | 3 | PARAM-Comms is an effort to develop a unified benchmarking framework to 4 | characterize training platform backends. Currently, the benchmark supports 5 | Pytorch Distributed and PyTorch-XLA backends. 6 | 7 | The PARAM-Comms benchmark offers a single point solution to perform both top-down 8 | (DLRM application) and bottoms-up (collectives) operations for any given 9 | communication backend. 10 | 11 | The Collective-Comms benchmark (`comms.py`) is designed similar to nccl-tests 12 | for evaluating collective operations, such as All-reduce and All-to-all, through PyTorch backends. 13 | The DLRM-Comms benchmark (`dlrm.py`) is similar to the open-source DLRM benchmark except it 14 | only implements communication primitives. 15 | The Trace Replay benchmark (`commsTraceReplay.py`) is designed to replay the communication patterns captured 16 | from any distributed PyTorch workloads. 
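To make the measurement concrete, below is a minimal sketch, in plain PyTorch, of the kind of timing loop these benchmarks automate (warm-up, timed iterations, message-size sweep) for a single collective. It is not the actual `comms.py` implementation; the function name, iteration counts, and message sizes are illustrative only, and it assumes a CUDA build of PyTorch with NCCL plus a `torchrun`-style launcher that sets `LOCAL_RANK`.

```python
import os
import time

import torch
import torch.distributed as dist


def time_all_reduce(num_elems: int, iters: int = 100, warmup: int = 5) -> float:
    """Average latency (us) of one all_reduce on a float32 tensor of num_elems elements."""
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    torch.cuda.set_device(device)
    tensor = torch.rand(num_elems, device=device)

    # Warm-up iterations are excluded from the measurement.
    for _ in range(warmup):
        dist.all_reduce(tensor)
    torch.cuda.synchronize(device)

    start = time.monotonic()
    for _ in range(iters):
        dist.all_reduce(tensor)
    torch.cuda.synchronize(device)
    return (time.monotonic() - start) / iters * 1e6


if __name__ == "__main__":
    # Launched e.g. via: torchrun --nproc-per-node=8 this_script.py
    dist.init_process_group(backend="nccl")
    for n in (2**10, 2**16, 2**22):  # sweep message sizes (in elements)
        lat_us = time_all_reduce(n)
        if dist.get_rank() == 0:
            print(f"{n:>8} elements: {lat_us:8.1f} us per all_reduce")
    dist.destroy_process_group()
```

The real benchmarks add the pieces this sketch omits, such as backend/device selection, bandwidth computation, and percentile statistics, as described in the usage examples below.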
17 | 18 | ## Usage: 19 | 20 | ### Collective-Comms benchmark (`comms.py`) 21 | ```bash 22 | mpirun -np -N --hostfile ./comms.py \ 23 | --master-ip 127.0.0.1 24 | --b \ 25 | --e \ 26 | --n \ 27 | --f \ 28 | --z \ 29 | --collective 30 | ``` 31 | Example: 32 | ```bash 33 | mpirun -np 16 -N 8 --hostfile ./hfile ./comms.py --master-ip $(head -n 1 ./hfile.txt) --b 8 --e 256M --n 100 \ 34 | --f 2 --z 1 --collective all_to_all --backend nccl --device cuda --log INFO 35 | ``` 36 | 37 | ### DLRM-Comms benchmark (`dlrm.py`) 38 | ```bash 39 | mpirun -np -N --hostfile ./dlrm.py \ 40 | --master-ip 41 | --arch-sparse-feature-size \ 42 | --arch-embedding-size \ 43 | --arch-mlp-bot \ 44 | --arch-mlp-top \ 45 | --mini-batch-size \ 46 | --num-batches 47 | ``` 48 | Example: 49 | ```bash 50 | mpirun -np 16 -N 8 --hostfile ./hfile ./dlrm.py --master-ip $(head -n 1 ./hfile.txt) --mini-batch-size 32 \ 51 | --num-batches 100 \ 52 | --arch-mlp-bot 1024-256 \ 53 | --arch-sparse-feature-size 64 \ 54 | --arch-embedding-size "10000-10000-10000-10000-10000-10000-10000-10000-10000-10000-10000-10000-10000-10000-10000-10000" 55 | ``` 56 | 57 | ### Trace Replay benchmark (`commsTraceReplay.py`) 58 | ```bash 59 | mpirun -np -N --hostfile ./commsTraceReplay.py \ 60 | --master-ip 127.0.0.1 --trace-path /path/to/traces --dry-run 61 | ``` 62 | Example: 63 | ```bash 64 | mpirun -np 16 -N 8 --hostfile ./hfile ./commsTraceReplay.py --master-ip $(head -n 1 ./hfile.txt) \ 65 | --backend nccl --device cuda \ 66 | --trace-path /path/to/commTraces 67 | ``` 68 | Note that there should be one trace file (in JSON format) per rank. 69 | -------------------------------------------------------------------------------- /train/comms/pt/logger_utils.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | from dataclasses import dataclass 8 | from enum import Enum 9 | from typing import abstractmethod, Dict, Optional 10 | 11 | from param_bench.train.comms.pt.pytorch_backend_utils import backendFunctions 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class benchType(Enum): 17 | Collective = 0 18 | Pt2Pt = 1 19 | QuantCollective = 2 20 | 21 | 22 | @dataclass 23 | class commsPerfMetrics: 24 | """ 25 | Base Class for storing performance metrics for communication op. 26 | """ 27 | 28 | commsOp: str = None 29 | Datatype: str = None 30 | BenchCommsType: int = None 31 | Backend: str = None 32 | Tags: str = "" 33 | InputSize: float = 0.0 34 | OutputSize: float = 0.0 35 | NumElements: int = 0 36 | NumElements_pair: int = 0 37 | 38 | 39 | @dataclass 40 | class commsQuantCollPerfMetrics(commsPerfMetrics): 41 | """ 42 | Class for storing performance metrics for a collective with quentization enabled. 43 | """ 44 | 45 | p95_latency_us: float = 0.0 46 | quant_p95_latency_us: float = 0.0 47 | dequant_p95_latency_us: float = 0.0 48 | quant_comms_p95_latency_us: float = 0.0 49 | TFLOPs: float | None = 0.0 50 | 51 | def __post_init__(self): 52 | self.BenchCommsType = benchType.QuantCollective 53 | 54 | 55 | @dataclass 56 | class commsCollPerfMetrics(commsPerfMetrics): 57 | """ 58 | Class for storing performance metrics for a collective. 
59 | """ 60 | 61 | p50_latency_us: float = 0.0 62 | p75_latency_us: float = 0.0 63 | p95_latency_us: float = 0.0 64 | min_latency_us: float = 0.0 65 | max_latency_us: float = 0.0 66 | AlgoBW_GBs: float = 0.0 67 | BusBW_GBs: float = 0.0 68 | TFLOPs: float | None = 0.0 69 | 70 | def __post_init__(self): 71 | self.BenchCommsType = benchType.Collective 72 | 73 | 74 | @dataclass 75 | class commsPt2PtPerfMetrics(commsPerfMetrics): 76 | """ 77 | Class for storing performance metrics for a point-to-point. 78 | """ 79 | 80 | p50_latency_us: float = 0.0 81 | p75_latency_us: float = 0.0 82 | p95_latency_us: float = 0.0 83 | AvgUniBW_GBs: float = 0.0 84 | AvgBiBW_GBs: float = 0.0 85 | TotalUniBW_GBs: float = 0.0 86 | TotalBiBW_GBs: float = 0.0 87 | 88 | def __post_init__(self): 89 | self.BenchCommsType = benchType.Pt2Pt 90 | 91 | 92 | class commsPerfLogger: 93 | """ 94 | Helper class for logging performance metrics. 95 | """ 96 | 97 | def __init__(self, loggerName: str): 98 | self.name = loggerName 99 | 100 | @abstractmethod 101 | def logPerf( 102 | self, 103 | benchmarkName: str, 104 | metrics: commsPerfMetrics, 105 | backendFuncs: backendFunctions, 106 | **kwargs, 107 | ): 108 | """ 109 | Log performance metrics for the collective. 110 | Args: 111 | benchmarkName: Name of benchmark, e.g., "comms" or "replay". 112 | metrics: Performance metrics for this collective. 113 | backendFuncs: Backend function/object used in this benchmark. 114 | Returns: 115 | None 116 | """ 117 | pass 118 | 119 | 120 | customized_perf_loggers: dict[str, commsPerfLogger] = {} 121 | 122 | 123 | def register_perf_logger( 124 | name: str, 125 | func: commsPerfLogger, 126 | ) -> None: 127 | global customized_perf_loggers 128 | customized_perf_loggers[name] = func 129 | logger.info(f"Registered custom perf logger {name}") 130 | -------------------------------------------------------------------------------- /train/comms/pt/matmul_perf_model.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/triton-lang/kernels/blob/main/kernels/matmul_perf_model.py 2 | 3 | # This file is taken from the upstream triton-lang/kernels repo. 
4 | # Currently that repo does not have a license file, so disabling 5 | # the license lint for now: 6 | # @lint-ignore-every LICENSELINT 7 | 8 | # flake8: noqa 9 | # pyre-ignore-all-errors 10 | import functools 11 | import heapq 12 | 13 | import torch 14 | 15 | from triton import cdiv 16 | from triton.runtime import driver 17 | from triton.testing import ( 18 | get_dram_gbps, 19 | get_max_simd_tflops, 20 | get_max_tensorcore_tflops, 21 | nvsmi, 22 | ) 23 | 24 | 25 | @functools.lru_cache 26 | def get_clock_rate_in_khz(): 27 | try: 28 | return nvsmi(["clocks.max.sm"])[0] * 1e3 29 | except FileNotFoundError: 30 | import pynvml 31 | 32 | pynvml.nvmlInit() 33 | handle = pynvml.nvmlDeviceGetHandleByIndex(0) 34 | return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 35 | 36 | 37 | def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): 38 | """return compute throughput in TOPS""" 39 | total_warps = num_ctas * min(num_warps, 4) 40 | num_subcores = ( 41 | driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 42 | ) # on recent GPUs 43 | tflops = ( 44 | min(num_subcores, total_warps) 45 | / num_subcores 46 | * get_max_tensorcore_tflops(dtype, get_clock_rate_in_khz(), device) 47 | ) 48 | return tflops 49 | 50 | 51 | def get_simd_tflops(device, num_ctas, num_warps, dtype): 52 | """return compute throughput in TOPS""" 53 | total_warps = num_ctas * min(num_warps, 4) 54 | num_subcores = ( 55 | driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 56 | ) # on recent GPUs 57 | tflops = ( 58 | min(num_subcores, total_warps) 59 | / num_subcores 60 | * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) 61 | ) 62 | return tflops 63 | 64 | 65 | def get_tflops(device, num_ctas, num_warps, dtype): 66 | capability = torch.cuda.get_device_capability(device) 67 | if capability[0] < 8 and dtype == torch.float32: 68 | return get_simd_tflops(device, num_ctas, num_warps, dtype) 69 | return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) 70 | 71 | 72 | def estimate_matmul_time( 73 | # backend, device, 74 | num_warps, 75 | num_stages, # 76 | A, 77 | B, 78 | C, # 79 | M, 80 | N, 81 | K, # 82 | BLOCK_M, 83 | BLOCK_N, 84 | BLOCK_K, 85 | SPLIT_K, # 86 | debug=False, 87 | **kwargs, # 88 | ): 89 | """return estimated running time in ms 90 | = max(compute, loading) + store""" 91 | device = torch.cuda.current_device() 92 | dtype = A.dtype 93 | dtsize = A.element_size() 94 | 95 | num_cta_m = cdiv(M, BLOCK_M) 96 | num_cta_n = cdiv(N, BLOCK_N) 97 | num_cta_k = SPLIT_K 98 | num_ctas = num_cta_m * num_cta_n * num_cta_k 99 | 100 | # If the input is smaller than the block size 101 | M, N = max(M, BLOCK_M), max(N, BLOCK_N) 102 | 103 | # time to compute 104 | total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS 105 | tput = get_tflops(device, num_ctas, num_warps, dtype) 106 | compute_ms = total_ops / tput 107 | 108 | # time to load data 109 | num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] 110 | active_cta_ratio = min(1, num_ctas / num_sm) 111 | active_cta_ratio_bw1 = min( 112 | 1, num_ctas / 32 113 | ) # 32 active ctas are enough to saturate 114 | active_cta_ratio_bw2 = max( 115 | min(1, (num_ctas - 32) / (108 - 32)), 0 116 | ) # 32-108, remaining 5% 117 | dram_bw = get_dram_gbps(device) * ( 118 | active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05 119 | ) # in GB/s 120 | l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) 
121 | # assume 80% of (following) loads are in L2 cache 122 | load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1)) 123 | load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1) 124 | load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1)) 125 | load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1) 126 | # total 127 | total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB 128 | total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) 129 | # loading time in ms 130 | load_ms = total_dram / dram_bw + total_l2 / l2_bw 131 | 132 | # estimate storing time 133 | store_bw = dram_bw * 0.6 # :o 134 | store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB 135 | if SPLIT_K == 1: 136 | store_ms = store_c_dram / store_bw 137 | else: 138 | reduce_bw = store_bw 139 | store_ms = store_c_dram / reduce_bw 140 | # c.zero_() 141 | zero_ms = M * N * 2 / (1024 * 1024) / store_bw 142 | store_ms += zero_ms 143 | 144 | total_time_ms = max(compute_ms, load_ms) + store_ms 145 | if debug: 146 | print( 147 | f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, " 148 | f"loading time: {load_ms}ms, store time: {store_ms}ms, " 149 | f"Activate CTAs: {active_cta_ratio*100}%" 150 | ) 151 | return total_time_ms 152 | 153 | 154 | def early_config_prune(configs, named_args, **kwargs): 155 | device = torch.cuda.current_device() 156 | capability = torch.cuda.get_device_capability() 157 | # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages 158 | dtsize = named_args["A"].element_size() 159 | dtype = named_args["A"].dtype 160 | 161 | # 1. make sure we have enough smem 162 | pruned_configs = [] 163 | for config in configs: 164 | kw = config.kwargs 165 | BLOCK_M, BLOCK_N, BLOCK_K, num_stages = ( 166 | kw["BLOCK_M"], 167 | kw["BLOCK_N"], 168 | kw["BLOCK_K"], 169 | config.num_stages, 170 | ) 171 | 172 | max_shared_memory = driver.active.utils.get_device_properties(device)[ 173 | "max_shared_mem" 174 | ] 175 | required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize 176 | if required_shared_memory <= max_shared_memory: 177 | pruned_configs.append(config) 178 | configs = pruned_configs 179 | 180 | # Some dtypes do not allow atomic_add 181 | if dtype not in [torch.float16, torch.float32]: 182 | configs = [config for config in configs if config.kwargs["SPLIT_K"] == 1] 183 | 184 | # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) 185 | configs_map = {} 186 | for config in configs: 187 | kw = config.kwargs 188 | BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = ( 189 | kw["BLOCK_M"], 190 | kw["BLOCK_N"], 191 | kw["BLOCK_K"], 192 | kw["SPLIT_K"], 193 | config.num_warps, 194 | config.num_stages, 195 | ) 196 | 197 | key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) 198 | if key in configs_map: 199 | configs_map[key].append((config, num_stages)) 200 | else: 201 | configs_map[key] = [(config, num_stages)] 202 | 203 | pruned_configs = [] 204 | for k, v in configs_map.items(): 205 | BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k 206 | if capability[0] >= 8: 207 | # compute cycles (only works for ampere GPUs) 208 | mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) 209 | mma_cycles = mmas / min(4, num_warps) * 8 210 | 211 | ldgsts_latency = 300 # Does this matter? 
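# The pipeline should be deep enough that tensor-core math hides the assumed
# ~300-cycle global->shared (LDGSTS) copy latency, so the target stage count below is
# that copy latency divided by the MMA cycles issued per stage; the selection key that
# follows penalizes stage counts under this optimum ("prefer large #stages").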
212 | optimal_num_stages = ldgsts_latency / mma_cycles 213 | 214 | # nearest stages, prefer large #stages 215 | nearest = heapq.nsmallest( 216 | 2, 217 | v, 218 | key=lambda x: ( 219 | 10 + abs(x[1] - optimal_num_stages) 220 | if (x[1] - optimal_num_stages) < 0 221 | else x[1] - optimal_num_stages 222 | ), 223 | ) 224 | 225 | for n in nearest: 226 | pruned_configs.append(n[0]) 227 | else: # Volta & Turing only supports num_stages <= 2 228 | random_config = v[0][0] 229 | random_config.num_stages = 2 230 | pruned_configs.append(random_config) 231 | return pruned_configs 232 | -------------------------------------------------------------------------------- /train/comms/pt/param_profile.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from __future__ import annotations 7 | 8 | import logging 9 | import time 10 | from dataclasses import dataclass 11 | from typing import Any 12 | 13 | from torch.autograd.profiler import record_function 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class paramProfile(record_function): 19 | """Inherit from PyTorch profiler to enable autoguard profiling while measuring the time interval in PARAM""" 20 | 21 | def __init__(self, timer: paramTimer = None, description: str = "") -> None: 22 | self.description = description 23 | self.timer = timer 24 | super().__init__(name=description) 25 | 26 | def __enter__(self) -> paramProfile: 27 | super().__enter__() 28 | self.start = time.monotonic() 29 | return self 30 | 31 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 32 | self.end = time.monotonic() 33 | self.intervalNS = (self.end - self.start) * 1e9 # keeping time in NS 34 | # if given a valid paramTimer object, directly update the measured time interval 35 | if isinstance(self.timer, paramTimer): 36 | self.timer.incrTimeNS(self.intervalNS) 37 | logger.debug(f"{self.description} took {self.intervalNS} ns") 38 | super().__exit__(exc_type, exc_value, traceback) 39 | 40 | 41 | @dataclass 42 | class paramTimer: 43 | """ 44 | Timer for param profiler. 45 | """ 46 | 47 | elapsedTimeNS: float = 0.0 # keeping time in NS 48 | 49 | def reset(self, newTime: float = 0.0) -> None: 50 | self.elapsedTimeNS = newTime 51 | 52 | def incrTimeNS(self, timeNS: float) -> None: 53 | self.elapsedTimeNS += timeNS 54 | 55 | def getTimeUS(self) -> float: 56 | return self.elapsedTimeNS / 1e3 57 | 58 | def getTimeNS(self) -> float: 59 | return self.elapsedTimeNS 60 | -------------------------------------------------------------------------------- /train/comms/pt/pytorch_tpu_backend.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch_xla.core.xla_model as xm 9 | import torch_xla.distributed.xla_multiprocessing as xmp 10 | from comms_utils import backendFunctions 11 | 12 | 13 | class PyTorchTPUBackend(backendFunctions): 14 | def sayHello(self): 15 | myhost = os.uname()[1] 16 | device = self.get_device() 17 | hw_device = self.get_hw_device() 18 | global_rank = self.get_global_rank() 19 | local_rank = self.get_local_rank() 20 | world_size = self.get_world_size() 21 | master_ip = self.bootstrap_info.master_ip 22 | print( 23 | "\tRunning on host: %s g-rank: %d, l-rank: %s world_size: %d master_ip: %s device: %s (%s)" 24 | % ( 25 | myhost, 26 | global_rank, 27 | local_rank, 28 | world_size, 29 | master_ip, 30 | device, 31 | hw_device, 32 | ) 33 | ) 34 | 35 | # Collectives 36 | def all_reduce(self, collectiveArgs, retFlag=False): 37 | retObj = xm.all_reduce(collectiveArgs.op, [collectiveArgs.ipTensor]) 38 | if collectiveArgs.asyncOp: 39 | collectiveArgs.waitObj.append(retObj) 40 | if retFlag: 41 | return retObj 42 | 43 | def reduce(self, collectiveArgs, retFlag=False): 44 | raise NotImplementedError("Func reduce: not implemented yet on TPU") 45 | 46 | def all_to_all(self, collectiveArgs, retFlag=False): 47 | retObj = xm.all_to_all(collectiveArgs.ipTensor, 0, 0, collectiveArgs.world_size) 48 | collectiveArgs.opTensor = retObj 49 | if collectiveArgs.asyncOp: 50 | collectiveArgs.waitObj.append(retObj) 51 | if retFlag: 52 | return retObj 53 | 54 | def all_to_allv(self, collectiveArgs, retFlag=False): 55 | raise NotImplementedError("Func all_to_allv: not implemented yet on TPU") 56 | 57 | def all_gather(self, collectiveArgs, retFlag=False): 58 | retObj = xm.all_gather(collectiveArgs.ipTensor, dim=0) 59 | collectiveArgs.opTensor = retObj 60 | if collectiveArgs.asyncOp: 61 | collectiveArgs.waitObj.append(retObj) 62 | if retFlag: 63 | return retObj 64 | 65 | def complete_accel_ops(self, collectiveArgs): 66 | xm.mark_step() 67 | 68 | def get_reduce_op(self, opName): 69 | if opName == "sum": 70 | return xm.REDUCE_SUM 71 | elif opName == "max": 72 | return xm.REDUCE_MAX 73 | else: 74 | return xm.REDUCE_SUM 75 | 76 | def barrier(self, collectiveArgs, name="world"): 77 | xm.rendezvous(name) 78 | 79 | # Compute functions 80 | def compute_mm(self, collectiveArgs): 81 | self.gemm(collectiveArgs) 82 | 83 | def gemm(self, collectiveArgs): 84 | collectiveArgs.MMout = torch.mm(collectiveArgs.MMin1, collectiveArgs.MMin2) 85 | 86 | # Memory related 87 | def get_mem_size(self, collectiveArgs): 88 | return ( 89 | collectiveArgs.ipTensor.nelement() * collectiveArgs.ipTensor.element_size() 90 | ) 91 | 92 | def alloc_random(self, sizeArr, curRankDevice, dtype, scaleFactor=1.0): 93 | if dtype in (torch.int32, torch.long): 94 | ipTensor = torch.randint( 95 | 0, 1000, sizeArr, device=curRankDevice, dtype=dtype 96 | ) 97 | else: 98 | ipTensor = torch.rand(sizeArr, device=curRankDevice, dtype=dtype) 99 | # ipTensor = torch.full( 100 | # sizeArr, self.get_global_rank(), device=curRankDevice, dtype=dtype 101 | # ) 102 | # print("IP: ", ipTensor, self.get_hw_device()) 103 | if (scaleFactor) != 0: 104 | ipTensor = ipTensor / scaleFactor 105 | return ipTensor 106 | 107 | def alloc_embedding_tables(self, n, m, curRankDevice, dtype): 108 | EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True) 109 | 110 | W = np.random.uniform( 111 | low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m) 112 | ).astype(np.float32) 113 | # approach 1 114 | 115 | EE.weight.data = 
torch.tensor( 116 | W, dtype=dtype, requires_grad=True, device=curRankDevice 117 | ) 118 | return EE 119 | 120 | def alloc_empty(self, sizeArr, dtype, curRankDevice): 121 | return torch.empty(sizeArr, device=curRankDevice, dtype=dtype) 122 | 123 | def clear_memory(self, collectiveArgs): 124 | pass # torch.cuda.empty_cache() 125 | 126 | # Getting world-size and other information. 127 | def get_local_rank( 128 | self, 129 | ): 130 | return xm.get_local_ordinal() 131 | 132 | def get_local_size( 133 | self, 134 | ): 135 | return self.bootstrap_info.local_size 136 | 137 | def get_global_rank( 138 | self, 139 | ): 140 | return xm.get_ordinal() 141 | 142 | def get_world_size( 143 | self, 144 | ): 145 | return xm.xrt_world_size() 146 | 147 | def get_device( 148 | self, 149 | ): 150 | return xm.xla_device() 151 | 152 | def get_hw_device( 153 | self, 154 | ): 155 | return xm._xla_real_device(xm.xla_device()) 156 | 157 | def get_default_group(self): 158 | pass 159 | 160 | def get_groups(self): 161 | pass 162 | 163 | def get_num_pgs(self): 164 | pass 165 | 166 | def tensor_list_to_numpy(self, tensorList): 167 | tensorList = torch.transpose(tensorList.view(-1, 1), 0, 1)[0] 168 | return tensorList.cpu().detach().numpy() 169 | 170 | # Init functions 171 | def __init__(self, bootstrap_info, commsParams): 172 | self.bootstrap_info = bootstrap_info 173 | self.commsParams = commsParams 174 | 175 | def initialize_backend( 176 | self, master_ip, master_port, backend="gloo", eager_mode=False 177 | ): 178 | pass 179 | 180 | def benchmark_comms(self, benchTime, commsParams): 181 | xmp.spawn( 182 | fn=benchTime, 183 | args=(commsParams, self), 184 | nprocs=self.bootstrap_info.num_tpu_cores, 185 | ) 186 | return 187 | 188 | def __del__(self): 189 | pass 190 | -------------------------------------------------------------------------------- /train/comms/pt/setup.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 | from setuptools import setup 3 | 4 | 5 | def main(): 6 | package_base = "param_bench.train.comms.pt" 7 | 8 | # List the packages and their dir mapping: 9 | # "install_destination_package_path": "source_dir_path" 10 | package_dir_map = { 11 | f"{package_base}": ".", 12 | } 13 | 14 | packages = list(package_dir_map) 15 | 16 | setup( 17 | name="parambench-train-comms", 18 | python_requires=">=3.8", 19 | author="Louis Feng", 20 | author_email="lofe@fb.com", 21 | url="https://github.com/facebookresearch/param", 22 | packages=packages, 23 | package_dir=package_dir_map, 24 | ) 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /train/comms/pt/tests/mocks/backend_mock.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MockBackendFunction: # Mock backend function 5 | # TODO: Add configurable options. 
6 | def __init__(self): 7 | self.collectiveFunc = { 8 | "all_to_all": self.all_to_all, 9 | "all_to_allv": self.all_to_allv, 10 | "all_reduce": self.all_reduce, 11 | "broadcast": self.broadcast, 12 | "all_gather": self.all_gather, 13 | "reduce": self.reduce, 14 | "barrier": self.barrier, 15 | "recv": self.recv, 16 | "noop": self.noop, 17 | } 18 | 19 | self.device = "cpu" 20 | self.world_size = 1 21 | self.local_rank = 0 22 | self.global_rank = 0 23 | self.group = "default" 24 | 25 | def noop(self, collectiveArgs=None, retFlag=False, pair=False): 26 | """no-op for the case we want to skip comms/compute""" 27 | pass 28 | 29 | def sayHello(self, global_rank, local_rank, world_size, master_ip): 30 | pass 31 | 32 | # Collectives 33 | def all_gather(self, collectiveArgs, retFlag=False): 34 | self.mock_collective(collectiveArgs) 35 | 36 | def all_reduce(self, collectiveArgs, retFlag=False): 37 | self.mock_collective(collectiveArgs) 38 | 39 | def broadcast(self, collectiveArgs, retFlag=False): 40 | self.mock_collective(collectiveArgs) 41 | 42 | def reduce(self, collectiveArgs, retFlag=False): 43 | self.mock_collective(collectiveArgs) 44 | 45 | def all_to_all(self, collectiveArgs, retFlag=False): 46 | self.mock_collective(collectiveArgs) 47 | 48 | def all_to_allv(self, collectiveArgs, retFlag=False): 49 | self.mock_collective(collectiveArgs) 50 | 51 | def recv(self, collectiveArgs, retFlag=False): 52 | self.mock_collective(collectiveArgs) 53 | 54 | def complete_accel_ops(self, collectiveArgs, devSync=False): 55 | self.mock_collective(collectiveArgs) 56 | 57 | def barrier(self, collectiveArgs, name="dummy"): 58 | self.mock_collective(collectiveArgs) 59 | 60 | def sync_barrier(self, collectiveArgs, desc="world"): 61 | self.barrier(collectiveArgs, name=desc) 62 | 63 | def mock_collective(self, collectiveArgs): 64 | # Mock this function to change collectiveArgs values. 65 | return collectiveArgs 66 | 67 | def get_reduce_op(self, opName): 68 | pass 69 | 70 | # Compute functions 71 | 72 | def gemm(self, collectiveArgs): 73 | pass 74 | 75 | # Memory related 76 | 77 | def get_mem_size(self, collectiveArgs): 78 | pass 79 | 80 | def alloc_embedding_tables(self, n, m, curRankDevice, dtype): 81 | pass 82 | 83 | def alloc_empty(self, sizeArr, dtype, curRankDevice): 84 | pass 85 | 86 | def clear_memory(self, collectiveArgs): 87 | pass 88 | 89 | # Getting world-size and other information. 
90 | 91 | def get_local_rank(self): 92 | return self.local_rank 93 | 94 | def get_global_rank(self): 95 | return self.global_rank 96 | 97 | def get_world_size(self): 98 | return self.world_size 99 | 100 | def get_device(self): 101 | return self.device 102 | 103 | def get_hw_device(self): 104 | return self.device 105 | 106 | def get_default_group(self): 107 | return self.group 108 | 109 | def get_groups(self): 110 | pass 111 | 112 | # Init functions 113 | 114 | def initialize_backend(self, master_ip, master_port, backend="gloo"): 115 | pass 116 | 117 | def benchmark_comms(self, benchTime, commsParams): 118 | pass 119 | 120 | def alloc_ones( 121 | self, sizeArr, curRankDevice="cpu", dtype=torch.int32, scaleFactor=1.0 122 | ): 123 | ipTensor = torch.ones(sizeArr, device=curRankDevice, dtype=dtype) 124 | if scaleFactor != 1.0: 125 | ipTensor = ipTensor * scaleFactor 126 | return ipTensor 127 | 128 | def alloc_random( 129 | self, sizeArr, curRankDevice="cpu", dtype=torch.int32, scaleFactor=1.0 130 | ): 131 | return self.alloc_ones( 132 | sizeArr, "cpu", dtype, 1.0 133 | ) # just return arrays of 1 for testing 134 | -------------------------------------------------------------------------------- /train/comms/pt/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains test classes with default values for comms unit tests. 3 | Feel free to add additional classes or modify existing ones as needed for new tests. 4 | """ 5 | 6 | from param_bench.train.comms.pt.comms_utils import commsArgs 7 | 8 | 9 | class testArgs: # default args to run tests with 10 | def __init__(self): 11 | self.trace_file = "" 12 | self.use_remote_trace = False 13 | self.dry_run = False 14 | self.auto_shrink = False 15 | self.max_msg_cnt = 0 # 0 means no limit 16 | self.num_msg = 0 17 | self.z = 0 18 | self.no_warm_up = True 19 | self.allow_ops = "" 20 | self.output_path = "/tmp/paramReplayedTrace" 21 | self.colls_per_batch = -1 22 | self.use_timestamp = False 23 | self.rebalance_policy = "" 24 | 25 | 26 | class commsParamsTest: 27 | def __init__(self): 28 | # A holding object for common input parameters, add as needed to test 29 | self.nw_stack = "pytorch_dist" 30 | self.dtype = "int" 31 | self.backend = "nccl" 32 | self.device = "cpu" 33 | self.blockingFlag = 1 34 | # quantization 35 | self.bitwidth = 32 36 | self.quant_a2a_embedding_dim = 1 37 | self.quant_threshold = 1 38 | self.dcheck = 1 39 | self.num_pgs = 1 40 | 41 | 42 | class bootstrap_info_test: 43 | def __init__(self): 44 | self.global_rank = 0 45 | self.local_rank = 0 46 | self.world_size = 16 47 | 48 | self.master_ip = "localhost" 49 | self.master_port = "25555" 50 | self.num_tpu_cores = 16 51 | 52 | 53 | def createCommsArgs(**kwargs) -> commsArgs: 54 | """ 55 | Test utility to create comms args from a dict of values. 56 | """ 57 | curComm = commsArgs() 58 | for key, value in kwargs.items(): 59 | setattr(curComm, key, value) 60 | 61 | return curComm 62 | -------------------------------------------------------------------------------- /train/compute/pt/README.md: -------------------------------------------------------------------------------- 1 | # PARAM benchmark -- compute benchmarks 2 | 3 | Unified compute kernel benchmarks for DLRM and other important AI workloads 4 | under PyTorch interface. 
5 | 6 | Currently, three kernels are identified: 7 | * GEMM (or MatMul) : Measure GEMM performance for matrix Z(m,n) = X(m,k) x Y(k, n) 8 | * MLP (multilayer perceptron) : Measure the performance of a series of FC layers 9 | * EmbeddingBag : Measure the EmbeddingBag performance for table lookup 10 | 11 | The benchmark is developed to measure the performance of an individual 12 | operation or kernel, and to compare that performance across 13 | different platforms, such as CPU, GPU, or TPU. 14 | 15 | The TPU implementation is through PyTorch/XLA. 16 | 17 | ## Usage 18 | 19 | A driver (`driver.py`) is provided, which can be used to run the different kernels. 20 | For each kernel, one or more datasets have been defined as 'A', 'B', 'C', etc. 21 | 22 | ```bash 23 | python3 driver.py -h 24 | usage: driver.py [-h] [--warmups WARMUPS] [--steps STEPS] --device {cpu,gpu,tpu} {gemm,emb,linear} ... 25 | 26 | Measuring the Compute Kernel Performance Using PyTorch 27 | 28 | optional arguments: 29 | -h, --help show this help message and exit 30 | --warmups WARMUPS warmup times 31 | --steps STEPS repeat times 32 | --device {cpu,gpu,tpu} 33 | valid devices 34 | 35 | kernels: 36 | {gemm,emb,linear} 37 | gemm measure mm performance (m,k)*(k,n)=(m,n) 38 | emb measure EmbeddingBag performance 39 | linear measure mlp performance 40 | ``` 41 | 42 | ### Testing GEMM 43 | 44 | ```bash 45 | python3 driver.py --steps=100 --device='cpu' gemm --dataset='A' 46 | 47 | Measuring the performance of gemm on device = cpu 48 | Steps = 100 warmups = 10 49 | with matrix dataset A , Data type: float32 50 | 51 | ---------------------------------------------------------------- 52 | M N K Time(s) Rate(GF/s) 53 | ---------------------------------------------------------------- 54 | 128, 4096, 4096, 0.519193 827.240 55 | 256, 4096, 4096, 1.005778 854.058 56 | 512, 4096, 4096, 2.214854 775.666 57 | 1024, 4096, 4096, 3.388758 1013.933 58 | 128, 1024, 1024, 0.555641 48.311 59 | 256, 1024, 1024, 0.145774 368.291 60 | 512, 1024, 1024, 0.177422 605.189 61 | 1024, 1024, 1024, 0.215082 998.447 62 | 63 | ``` 64 | 65 | ### Testing EmbeddingBag 66 | ```bash 67 | python3 driver.py --steps=100 --device='cpu' emb --dataset='A' 68 | 69 | Measuring the performance of emb on device = cpu 70 | Steps = 10 warmup = 1 71 | with emb data A.
72 | --------------------------------------------------------------------------------- 73 | Features embdim nnz batch Time(s)/step Data(MB) BW(GB/s) 74 | --------------------------------------------------------------------------------- 75 | 14000000, 128, 30, 2048, 0.002067, 31.5, 15.222 76 | 14000000, 128, 30, 4096, 0.004611, 62.9, 13.644 77 | 14000000, 128, 30, 8192, 0.006464, 125.8, 19.466 78 | 14000000, 128, 30, 16384, 0.009102, 251.7, 27.649 79 | 80 | ``` 81 | Note that on TPU, due to the current performance concern of EmbeddingBag, we 82 | also support an alternative implementation, XlaEmbeddingBag, which can be 83 | invoked through --usexlabag 84 | 85 | Example: Measure the performance of a MLP with 18 hidden layer, layer size 1024 86 | ```bash 87 | python pytorch_linear.py --device gpu --layer-num 18 --batch-size 128 --input-size 1024 \ 88 | --hidden-size 1024 --output-size 1024 --steps 100 \ 89 | --dtype=float16 --optimizer-type=sgd 90 | ``` 91 | 92 | ### Testing MLP Linear 93 | ```bash 94 | python3 driver.py --steps=100 --device='cpu' linear --dataset='A' 95 | 96 | Measuring the performance of linear on device = cpu 97 | Steps = 10 warmups = 1 98 | with linear dataset A , Data type: float 99 | -------------------------------------------------------------------------------- 100 | #Layer Input Hidden Output Batch Time(s)/step QPS Rate(GF/s) 101 | -------------------------------------------------------------------------------- 102 | 103 | 18, 1024, 1024, 1024, 128, 0.344426, 371.6, 46.8 104 | 18, 1024, 1024, 1024, 256, 0.206910, 1237.3, 155.7 105 | 18, 1024, 1024, 1024, 512, 0.279407, 1832.5, 230.6 106 | 107 | ``` 108 | 109 | In addition, you can run individual kernels using 110 | ```bash 111 | python3 pytorch_gemm.py ... 112 | ``` 113 | ```bash 114 | python3 pytorch_emb.py ... 115 | ``` 116 | ```bash 117 | python3 pytorch_linear.py ... 118 | ``` 119 | -------------------------------------------------------------------------------- /train/compute/pt/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
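# Named problem-size datasets consumed by driver.py: its --dataset flag ("A", "B",
# "C") selects one of the lists below, and every tuple in the list is benchmarked.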
5 | 6 | # gemm tuple shape (M, N, K) 7 | gemm_A = [ 8 | (128, 4096, 4096), 9 | (256, 4096, 4096), 10 | (512, 4096, 4096), 11 | (1024, 4096, 4096), 12 | (128, 1024, 1024), 13 | (256, 1024, 1024), 14 | (512, 1024, 1024), 15 | (1024, 1024, 1024), 16 | (4096, 4096, 128), 17 | (4096, 4096, 256), 18 | (4096, 4096, 512), 19 | (4096, 4096, 1024), 20 | (1024, 1024, 128), 21 | (1024, 1024, 256), 22 | (1024, 1024, 512), 23 | ] 24 | 25 | gemm_B = [ 26 | (128, 4096, 40928), 27 | (256, 4096, 40928), 28 | (512, 4096, 40928), 29 | (1024, 4096, 40928), 30 | (128, 40928, 4096), 31 | (256, 40928, 4096), 32 | (512, 40928, 4096), 33 | (1024, 40928, 4096), 34 | (128, 1024, 2000), 35 | (256, 1024, 2000), 36 | (512, 1024, 2000), 37 | (1024, 1024, 2000), 38 | (1024, 2000, 128), 39 | (1024, 2000, 256), 40 | (1024, 2000, 512), 41 | (1024, 2000, 1024), 42 | (4096, 40928, 128), 43 | (4096, 40928, 256), 44 | (4096, 40928, 512), 45 | (4096, 40928, 1024), 46 | ] 47 | 48 | gemm_C = [ 49 | (1024, 1024, 64), 50 | (1024, 64, 1024), 51 | (1024, 4096, 1024), 52 | (1024, 1024, 4096), 53 | ] 54 | 55 | # emb tuple (features, embdim, nnz, batch) 56 | emb_A = [ 57 | (14000000, 128, 30, 512), 58 | (14000000, 128, 30, 1024), 59 | (14000000, 128, 30, 2048), 60 | (14000000, 128, 30, 4096), 61 | (14000000, 128, 30, 8192), 62 | (14000000, 128, 30, 16384), 63 | (14000000, 128, 30, 32768), 64 | (14000000, 128, 30, 65536), 65 | (26000000, 128, 30, 512), 66 | (26000000, 128, 30, 1024), 67 | (26000000, 128, 30, 2048), 68 | (26000000, 128, 30, 4096), 69 | (26000000, 128, 30, 8192), 70 | (26000000, 128, 30, 16384), 71 | (26000000, 128, 30, 32768), 72 | (26000000, 128, 30, 65536), 73 | ] 74 | 75 | emb_B = [ 76 | (4800000, 56, 34, 2048), 77 | (4800000, 56, 34, 4096), 78 | (4800000, 56, 34, 8192), 79 | (4800000, 56, 34, 16384), 80 | (4800000, 56, 34, 32768), 81 | (4800000, 56, 34, 65536), 82 | ] 83 | 84 | # mlp tuple (layer-num, input-size, hidden-size, output-size, batch-size) 85 | mlp_A = [ 86 | (18, 1024, 1024, 1024, 128), 87 | (18, 1024, 1024, 1024, 256), 88 | (18, 1024, 1024, 1024, 512), 89 | (18, 1024, 1024, 1024, 1024), 90 | (18, 1024, 1024, 1024, 2048), 91 | (18, 1024, 1024, 1024, 4096), 92 | (18, 4096, 4096, 4096, 128), 93 | (18, 4096, 4096, 4096, 256), 94 | (18, 4096, 4096, 4096, 512), 95 | (18, 4096, 4096, 4096, 1024), 96 | (18, 4096, 4096, 4096, 2048), 97 | (18, 4096, 4096, 4096, 4096), 98 | ] 99 | -------------------------------------------------------------------------------- /train/compute/pt/driver.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import dataset 7 | import pytorch_emb as kemb 8 | import pytorch_gemm as kgemm 9 | import pytorch_linear as klinear 10 | 11 | 12 | def main() -> None: 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser( 16 | description="Measuring the Compute Kernel Performance Using PyTorch" 17 | ) 18 | parser.add_argument("--warmups", type=int, default=10, help="warmup times") 19 | parser.add_argument("--steps", type=int, default=100, help="repeat times") 20 | parser.add_argument( 21 | "--device", 22 | type=str, 23 | choices=["cpu", "gpu", "tpu"], 24 | required=True, 25 | help="valid devices", 26 | ) 27 | 28 | subparsers = parser.add_subparsers(title="kernels", dest="kernel") 29 | subparsers.required = True 30 | 31 | parser_gemm = subparsers.add_parser( 32 | "gemm", help="measure mm performance (m,k)*(k,n)=(m,n)" 33 | ) 34 | parser_gemm.add_argument("-t", "--dtype", type=str, default="float32") 35 | parser_gemm.add_argument("-d", "--dataset", choices=["A", "B", "C"], default="A") 36 | 37 | parser_emb = subparsers.add_parser("emb", help="measure EmbeddingBag performance") 38 | parser_emb.add_argument("-d", "--dataset", choices=["A", "B"], default="A") 39 | parser_emb.add_argument("--randomseed", type=int, default=0) 40 | parser_emb.add_argument( 41 | "--usexlabag", action="store_true", help="use xlabad instead of embeddingbag" 42 | ) 43 | parser_emb.add_argument( 44 | "--alpha", default=0.0, help="Zipf param. Use uniform if == 0.0" 45 | ) 46 | 47 | parser_linear = subparsers.add_parser("linear", help="measure mlp performance") 48 | parser_linear.add_argument("--optimizer", action="store_true") 49 | parser_linear.add_argument( 50 | "-t", 51 | "--dtype", 52 | default="float", 53 | help="data type", 54 | choices=["float", "float16", "bfloat16"], 55 | ) 56 | parser_linear.add_argument("-d", "--dataset", choices=["A"], default="A") 57 | parser_linear.add_argument("--debug", action="store_false", default=False) 58 | parser_linear.add_argument("--fw-only", action="store_false", default=False) 59 | parser.add_argument("--set-to-none", action="store_false", default=False) 60 | parser.add_argument("--explicit-cast", action="store_true", default=True) 61 | parser_linear.add_argument( 62 | "--optimizer-type", 63 | default="sgd", 64 | help="Optimizer: SGD", 65 | choices=["sgd", "adagrad"], 66 | ) 67 | 68 | args = parser.parse_args() 69 | 70 | print("Measuring the performance of ", args.kernel, " on device = ", args.device) 71 | print("Steps = ", args.steps, " warmups = ", args.warmups) 72 | if args.kernel == "gemm": 73 | print("with matrix dataset ", args.dataset, ", Data type: ", args.dtype) 74 | print(" ") 75 | if args.dataset == "A": 76 | kgemm.run(args, dataset.gemm_A) 77 | elif args.dataset == "B": 78 | kgemm.run(args, dataset.gemm_B) 79 | else: 80 | kgemm.run(args, dataset.gemm_C) 81 | 82 | elif args.kernel == "emb": 83 | print("with emb dataset ", args.dataset) 84 | if args.dataset == "A": 85 | kemb.run(args, dataset.emb_A) 86 | elif args.dataset == "B": 87 | kemb.run(args, dataset.emb_B) 88 | 89 | else: 90 | print("with linear dataset ", args.dataset, ", Data type: ", args.dtype) 91 | if args.dataset == "A": 92 | ds = [] 93 | for i in range(len(dataset.mlp_A)): 94 | layers_size = [] 95 | ( 96 | layer_num, 97 | input_size, 98 | hidden_size, 99 | output_size, 100 | batch_size, 101 | ) = dataset.mlp_A[i] 102 | layers_size.append(input_size) 103 | for _ in range(layer_num): 104 | layers_size.append(hidden_size) 105 | layers_size.append(output_size) 106 | 107 | ds.append((layers_size, 
batch_size)) 108 | 109 | klinear.run(args, ds) 110 | 111 | 112 | if __name__ == "__main__": 113 | main() # pragma: no cover 114 | -------------------------------------------------------------------------------- /train/compute/pt/pytorch_cutlass.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import sys 7 | import time 8 | 9 | import torch 10 | 11 | 12 | def measure_blas(a, b, steps): 13 | global c 14 | torch.cuda.synchronize() 15 | start = time.perf_counter() 16 | for _ in range(steps): 17 | c = torch.mm(a, b) 18 | torch.cuda.synchronize() 19 | end = time.perf_counter() 20 | c.to("cpu") 21 | return end - start 22 | 23 | 24 | def measure_tlass(a, b, steps): 25 | torch.ops.load_library("//caffe2/torch/fb/cutlass:cutlass_gemm") 26 | 27 | global c 28 | torch.cuda.synchronize() 29 | start = time.perf_counter() 30 | for _ in range(steps): 31 | c = torch.ops.fb.mm(a, b) 32 | torch.cuda.synchronize() 33 | end = time.perf_counter() 34 | c.to("cpu") 35 | return end - start 36 | 37 | 38 | def run_single(args, m, n, k, func): 39 | dtype = args.dtype 40 | warmups = args.warmups 41 | steps = args.steps 42 | 43 | dt = torch.float32 44 | if dtype == "float16" or dtype == "half": 45 | dt = torch.float16 46 | elif dtype == "bfloat16": 47 | dt = torch.bfloat16 48 | 49 | torch.manual_seed(0) 50 | 51 | elap = 0.0 52 | 53 | a = torch.randn( 54 | m, 55 | k, 56 | ).to(dt) 57 | b = torch.randn(k, n).to(dt) 58 | c = torch.zeros(m, n).to(dt) 59 | 60 | if torch.cuda.is_available(): 61 | # ncuda = torch.cuda.device_count() 62 | # print("There are {} cuda devices".format(ncuda)) 63 | # print("The first cuda device name is {} ".format(torch.cuda.get_device_name())) 64 | cuda0 = torch.device("cuda:0") 65 | with torch.cuda.device(cuda0): 66 | acuda = a.to(cuda0) 67 | bcuda = b.to(cuda0) 68 | if func == "blas": 69 | measure_blas(acuda, bcuda, warmups) 70 | elap = measure_blas(acuda, bcuda, steps) 71 | else: 72 | measure_tlass(acuda, bcuda, warmups) 73 | elap = measure_tlass(acuda, bcuda, steps) 74 | else: 75 | print("CUDA is not available") 76 | sys.exit(1) 77 | 78 | return elap 79 | 80 | 81 | def run(args, dataset): 82 | print("----------------------------------------------------------------") 83 | print(" M N K Time(s) Rate(TF/s)") 84 | print("----------------------------------------------------------------") 85 | for i in range(len(dataset)): 86 | m, n, k = dataset[i] 87 | elap = run_single(args, m, n, k, "blas") 88 | elap /= args.steps 89 | print( 90 | "{:10}, {:10}, {:10}, {:10.6f} {:.3f} ".format( 91 | m, n, k, elap, m * n * k * 2 * 1.0 / elap / 1.0e12 92 | ) 93 | ) 94 | elap = run_single(args, m, n, k, "tlass") 95 | elap /= args.steps 96 | print( 97 | "{:10}, {:10}, {:10}, {:10.6f} {:.3f} ".format( 98 | m, n, k, elap, m * n * k * 2 * 1.0 / elap / 1.0e12 99 | ) 100 | ) 101 | 102 | 103 | def main() -> None: 104 | import argparse 105 | 106 | parser = argparse.ArgumentParser( 107 | description="Measure and compare the performance of GEMM cuBlas and cuTlass" 108 | ) 109 | # model related parameters 110 | parser.add_argument("-m", "--msize", type=int, default=1024) 111 | parser.add_argument("-n", "--nsize", type=int, default=1024) 112 | parser.add_argument("-k", "--ksize", type=int, default=1024) 113 | parser.add_argument("-t", "--dtype", type=str, default="float32") 114 | 
parser.add_argument("--steps", type=int, default=100) 115 | parser.add_argument("--warmups", type=int, default=10) 116 | args = parser.parse_args() 117 | 118 | d = [(args.msize, args.nsize, args.ksize)] 119 | run(args, d) 120 | 121 | 122 | if __name__ == "__main__": 123 | main() # pragma: no cover 124 | -------------------------------------------------------------------------------- /train/compute/pt/pytorch_gemm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import sys 7 | import time 8 | 9 | import torch 10 | 11 | 12 | def measure_cpu(a, b, steps): 13 | global c 14 | start = time.perf_counter() 15 | for i in range(steps): 16 | c = torch.mm(a, b) 17 | end = time.perf_counter() 18 | c.to("cpu") 19 | return end - start 20 | 21 | 22 | def measure_gpu(a, b, steps): 23 | global c 24 | torch.cuda.synchronize() 25 | start = time.perf_counter() 26 | for i in range(steps): 27 | c = torch.mm(a, b) 28 | torch.cuda.synchronize() 29 | end = time.perf_counter() 30 | c.to("cpu") 31 | return end - start 32 | 33 | 34 | def measure_xla(a, b, steps): 35 | import torch_xla 36 | 37 | def sync(tensor, dev): 38 | torch_xla._XLAC._xla_sync_multi( 39 | [tensor], devices=[str(dev)], wait=True, sync_xla_data=True 40 | ) 41 | 42 | c = torch.mm(a, b) 43 | 44 | start = time.perf_counter() 45 | for _ in range(steps): 46 | # Add data dependency to prevent loop elimination 47 | # The PyTorch/XLA lazy evaluation will eliminate the loop 48 | # Simplier data dependency will not work 49 | b[0] = torch.min(c[0], b[0]) 50 | c = torch.min(c, torch.mm(a, b)) 51 | 52 | sync(c, c.device) 53 | end = time.perf_counter() 54 | # c.to('cpu') 55 | return end - start 56 | 57 | 58 | def run_single(args, m, n, k): 59 | dtype = args.dtype 60 | device = args.device 61 | warmups = args.warmups 62 | steps = args.steps 63 | 64 | dt = torch.float32 65 | if dtype == "float16" or dtype == "half": 66 | dt = torch.float16 67 | elif dtype == "bfloat16": 68 | dt = torch.bfloat16 69 | elif dtype == "tf32": 70 | torch.backends.cudnn.allow_tf32 = True 71 | torch.backends.cuda.matmul.allow_tf32 = True 72 | 73 | torch.manual_seed(0) 74 | 75 | elap = 0.0 76 | 77 | a = torch.randn(m, k).to(dt) 78 | b = torch.randn(k, n).to(dt) 79 | c = torch.zeros(m, n).to(dt) 80 | 81 | if device == "cpu": 82 | measure_cpu(a, b, warmups) 83 | elap = measure_cpu(a, b, steps) 84 | 85 | elif device == "gpu": 86 | if torch.cuda.is_available(): 87 | # ncuda = torch.cuda.device_count() 88 | # print("There are {} cuda devices".format(ncuda)) 89 | # print("The first cuda device name is {} ".format(torch.cuda.get_device_name())) 90 | cuda0 = torch.device("cuda:0") 91 | with torch.cuda.device(cuda0): 92 | acuda = a.to(cuda0) 93 | bcuda = b.to(cuda0) 94 | measure_gpu(acuda, bcuda, warmups) 95 | elap = measure_gpu(acuda, bcuda, steps) 96 | else: 97 | print("CUDA is not available") 98 | sys.exit(1) 99 | 100 | else: 101 | # import torch_xla 102 | import torch_xla.core.xla_model as xm 103 | 104 | # alldev = xm.get_xla_supported_devices() 105 | # allrealdev = xm.xla_real_devices(alldev) 106 | # print("Found {0} XLA devices: {1}".format(len(allrealdev), allrealdev)) 107 | 108 | dev = xm.xla_device() 109 | a = a.to(dev) 110 | b = b.to(dev) 111 | c = c.to(dev) 112 | measure_xla(a, b, warmups) 113 | xm.mark_step() 114 | elap = measure_xla(a, b, steps) 115 | 
xm.mark_step() 116 | 117 | return elap 118 | 119 | 120 | def run(args, dataset): 121 | print("----------------------------------------------------------------") 122 | print(" M N K Time(s) Rate(TF/s)") 123 | print("----------------------------------------------------------------") 124 | for i in range(len(dataset)): 125 | m, n, k = dataset[i] 126 | elap = run_single(args, m, n, k) 127 | elap /= args.steps 128 | print( 129 | "{:10}, {:10}, {:10}, {:10.6f} {:.3f} ".format( 130 | m, n, k, elap, m * n * k * 2 * 1.0 / elap / 1.0e12 131 | ) 132 | ) 133 | 134 | 135 | def main() -> None: 136 | import argparse 137 | 138 | parser = argparse.ArgumentParser( 139 | description="Measure the performance of GEMM using mm, or matmul" 140 | ) 141 | # model related parameters 142 | parser.add_argument("-m", "--msize", type=int, default=1024) 143 | parser.add_argument("-n", "--nsize", type=int, default=1024) 144 | parser.add_argument("-k", "--ksize", type=int, default=1024) 145 | parser.add_argument("-t", "--dtype", type=str, default="float32") 146 | parser.add_argument( 147 | "-d", "--device", choices=["cpu", "gpu", "tpu"], type=str, default="cpu" 148 | ) 149 | parser.add_argument("--steps", type=int, default=100) 150 | parser.add_argument("--warmups", type=int, default=10) 151 | args = parser.parse_args() 152 | 153 | d = [(args.msize, args.nsize, args.ksize)] 154 | run(args, d) 155 | 156 | 157 | if __name__ == "__main__": 158 | main() # pragma: no cover 159 | -------------------------------------------------------------------------------- /train/compute/python/.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | /dist 3 | *.so 4 | *.egg-info 5 | /*.json 6 | /*.log 7 | __pycache__ 8 | -------------------------------------------------------------------------------- /train/compute/python/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/train/compute/python/__init__.py -------------------------------------------------------------------------------- /train/compute/python/development.md: -------------------------------------------------------------------------------- 1 | # PARAM Compute Benchmark Development 2 | 3 | **For installation and basic usage instructions, please see [READM.md](README.md).** 4 | 5 | ## File Structures 6 | 7 | Directories 8 | 9 | * [`python`](.) 10 | * Base dir for Python benchmarks, including tool scripts. 11 | * [`python/examples`](./examples) 12 | * Example scripts and configuration files. 13 | * [`python/lib`](./lib) 14 | * Benchmark library modules and utilities. 15 | * [`python/pytorch`](./pytorch) 16 | * PyTorch framework benchmark scripts. 17 | * [`python/test`](./test) 18 | * Unit tests and test config files. 19 | * [`python/tools`](./tools) 20 | * General tool scripts. 21 | * [`python/workloads`](./workloads) 22 | * Implementation of workloads (operators). 23 | 24 | ML framework specific modules and files are in separate directories (e.g. `pytorch`) under these top level directories. 25 | 26 | Because the benchmark library and workloads are intended to be used both inside and outside of Facebook, we need to make sure that they work consistently in both scenarios. 
**Within the benchmark library package itself, we prefer to use relative imports**, for example: 27 | 28 | ```python 29 | from ..config import OperatorConfig 30 | from ..iterator import ConfigIterator 31 | from ..operator import OperatorInterface 32 | ``` 33 | This allows the top level package name to change without affecting the library code itself. 34 | 35 | ## Operator Interface 36 | The [`OperatorInterface`](lib/operator.py) specifies the interface each workload should support. At a minimum it should implement the `forward(*args, **kwargs)` method. 37 | 38 | * `build(*args, **kwargs)`: [optional] 39 | * initialize and constructs all necessary data and objects to run the operator workload. It takes positional and keyword arguments from the configuration file. 40 | * `cleanup()`: [optional] 41 | * release and delete any data and objects retained by this operator, its state should reset to before `build()` is called. This is called after a benchmark is run, so subsequent benchmarks do not run out of resource. 42 | * `forward(*args, **kwargs)`: [required] 43 | * runs the forward pass of the operator and stores the output for running `backward()`. 44 | * `create_grad()`: [optional] 45 | * create the gradient needed to run the `backward()` pass. This step is explicit to avoid counting this part in the benchmark latency for the backward pass. 46 | * `backward()`: [optional] 47 | * Use the result from `forward()` and gradient generated in `create_grad()` to run the backward pass. 48 | 49 | ### Auto Discovery of Workloads 50 | Python `pkgutil.iter_modules` provides a mechanism for discovering and importing modules dynamically. This allows adding workloads through the following simple steps: 51 | * Create or add to an operator workload python file in [`workloads`](workloads) directory 52 | * Implement the [`OperatorInterface`](lib/operator.py) 53 | * Register the new operator through one of the following 54 | * [`register_operator(name: str, operator: OperatorInterface)`](lib/operator.py) 55 | * [`register_operators(op_dict: Dict[str, OperatorInterface])`](lib/operator.py) 56 | 57 | The benchmark tool script will be able to load configuration files and instantiate corresponding operators for benchmarking. Two categories of of operators: 58 | * PyTorch native 59 | * Operators have no dependencies other than official PyTorch release. 60 | * External 61 | * Operators require additional installation. 62 | 63 | For users who do not have certain external operators in their environment, automatically importing these can cause errors. Auto import will try/catch these errors and skip these operators. 64 | 65 | ## Configuration Iterator 66 | Given a list of configurations (**build** or **input**), we need some mechanism to iterate over them. The overall logic is simple (**for illustration, not actual code**): 67 | 68 | ```python 69 | for build in build_configs: 70 | build_args, build_kwargs = materialize_config(build) 71 | op = Op.build(build_args, build_kwargs) 72 | for input in input_configs: 73 | input_args, input_kwargs = materialize_config(input) 74 | op.forward(input_args, input_kwargs) 75 | ``` 76 | 77 | There are some finer details: 78 | * Often we want to quickly generate many variations of build and input configurations without explicitly specifying each of them. This demands some mechanism for [**macros**](#macros). 79 | * The configuration is only a specification, further it may need to be expanded (if using macro) before materializing or generating the data. 
80 | * Current implementations: 81 | * `DefaultConfigIterator` 82 | * `RangeConfigIterator` 83 | 84 | If the existing configuration iterators do not satisfy your use case, a new iterator implementation that supports the [`ConfigIterator`](lib/iterator.py) interface can be registered using 85 | [`register_config_iterator(name: str, iterator_class: Type[ConfigIterator])`](lib/iterator.py). 86 | 87 | ### Macros 88 | Macros are a convenience to reduce the number of configurations that must be specified manually. 89 | 90 | #### `__range__` 91 | **`__range__`** defines a list of attributes that have range specifications. 92 |
 93 | "__range__": ["attr_name_1",...]
 94 | 
95 | 96 | **Example** 97 | ```json 98 | "args": [ 99 | { 100 | "type": "tensor", 101 | "dtype": "float", 102 | "shape": [512, [512, 514, 1], 30], 103 | "__range__": ["shape"] 104 | } 105 | ] 106 | ``` 107 | In the above example, the argument is a `tensor` type. Its `"__range__"` macro specifies that the `"shape"` attribute has range values: `[512, [512, 514, 1], 30]`. The second value of the shape is a list `[512, 514, 1]`, which represents `[min, max, step]`. During configuration iteration, multiple configurations will be generated, each with a different `"shape"` attribute after expansion: 108 | * `[512, 512, 30]` 109 | * `[512, 513, 30]` 110 | * `[512, 514, 30]` 111 | 112 | The `"__range__"` macro also works for non-numeric values like `bool`, `str`, etc. These values can be specified in a list, e.g., 113 | ```json 114 | { 115 | "type": "bool", 116 | "value": [true, false], 117 | "__range__": ["value"] 118 | } 119 | ``` 120 | 121 | #### `__copy__` 122 | **Only the `tensor` data type in positional `"args"` is supported.** 123 | 124 | In some instances, we need to ensure certain values are consistent between two attributes. For example, the input of a `matmul` operator has two tensors of shapes `A = [m, n]` and `B = [j, k]` where `n == j` for the inputs to be valid. As each of these values can vary between input configurations, to ensure `j == n` the `__copy__` macro is applied to the data type attributes after tensor shape `A` is specified and copies the value of `n` to the value of `j` in tensor shape `B`. 125 |
126 | "__copy__": [{"src_attr_name":[i, [j, k]]},...]
127 | 
128 | Defines a list of attributes and where to copy their values from. 129 | * `"src_attr_name"`: source attribute name 130 | * `i`: target element index 131 | * `j`: source **argument** index 132 | * `k`: source **element** index 133 | Copy the value from the source argument at index `j`, element index `k`, to the current argument's attribute element at index `i`. 134 | 135 | **Example** 136 | ```json 137 | "input": [ 138 | { 139 | "type": "tensor", 140 | "dtype": "float", 141 | "shape": [-1, 64, 128], 142 | "__copy__": [ 143 | { 144 | "shape": [0, [1, 2]] 145 | } 146 | ] 147 | }, 148 | { 149 | "type": "tensor", 150 | "dtype": "float", 151 | "shape": [8, 16, 32] 152 | } 153 | ] 154 | ``` 155 | In the above example, the first tensor argument's `"shape"` value at element index `0` (currently `-1`) is copied from the argument at position `1`, taking that argument's `"shape"` value at element index `2` (the value `32`). After the copy macro is applied, the tensor argument at index `0` will have shape `[32, 64, 128]`. 156 | 157 | ## Data Generator 158 | The role of the data generator is, given a configuration specification, to generate the actual data (scalar, boolean, string, tensor, etc.) for building or executing an operator. 159 | 160 | The current implementation provides a default data generator that supports PyTorch data types (see [PyTorch Data Types](#pyTorch-data-types)): 161 | * [`PyTorch:DefaultDataGenerator`](lib/pytorch/data_impl.py) 162 | 163 | If needed, it's possible to implement custom data generators based on the [`DataGenerator`](lib/data.py) interface. They can be registered using 164 | [`register_data_generator(name: str, data_gen_class: Type[DataGenerator])`](lib/data.py). 165 | 166 | ## Timer 167 | A timer is essential for measuring operator latency. Some devices (GPUs) are asynchronous and require special steps to run in blocking or synchronized mode. Depending on where the operator will run, the proper timer should be used: 168 | * CPU 169 | * GPU (PyTorch) 170 | 171 | In the future, we may support timers for other device types.
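To make the pieces above concrete, here is a minimal sketch of a new workload. The operator and module names are hypothetical and the relative import depth depends on where the file lives under `workloads/`, but the shape follows the `OperatorInterface` contract described earlier (only `forward()` is strictly required):

```python
import torch

# Relative import per the convention above; adjust the number of dots to the
# (hypothetical) location of this file within the package.
from ..lib.operator import OperatorInterface, register_operator


class AddMMOp(OperatorInterface):
    """Hypothetical example operator computing beta * c + alpha * (a @ b)."""

    def build(self, device: str = "cpu"):
        # No persistent tensors are needed; just remember where inputs will live.
        self.device = device
        self.fwd_out = None
        self.grad_in = None

    def forward(self, c, a, b, beta: float = 1.0, alpha: float = 1.0):
        # Run the op and keep the output so backward() can reuse it.
        self.fwd_out = torch.addmm(c, a, b, beta=beta, alpha=alpha)
        return self.fwd_out

    def create_grad(self):
        # Created outside the timed backward pass, as described above.
        self.grad_in = torch.ones_like(self.fwd_out)

    def backward(self):
        self.fwd_out.backward(self.grad_in)

    def cleanup(self):
        # Reset to the pre-build() state so later benchmarks start clean.
        self.fwd_out = None
        self.grad_in = None


# Register under a config name so the benchmark scripts can discover it.
register_operator("example.addmm", AddMMOp())
```

A benchmark config would then refer to the operator by its registered name (`"example.addmm"` here), using the same JSON format as the files under `examples/pytorch/configs/`.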
172 | -------------------------------------------------------------------------------- /train/compute/python/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/train/compute/python/examples/__init__.py -------------------------------------------------------------------------------- /train/compute/python/examples/cuda/ncu_args.txt: -------------------------------------------------------------------------------- 1 | --metrics dram__bytes.sum,l1tex__t_bytes.sum,lts__t_bytes.sum,sm__cycles_elapsed.avg,sm__cycles_elapsed.avg.per_second,sm__sass_thread_inst_executed_op_fadd_pred_on.sum,sm__sass_thread_inst_executed_op_ffma_pred_on.sum,sm__sass_thread_inst_executed_op_fmul_pred_on.sum 2 | -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/train/compute/python/examples/pytorch/__init__.py -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/configs/alex_net.json: -------------------------------------------------------------------------------- 1 | { 2 | "pytorch.model.alex_net": { 3 | "build_data_generator": "PyTorch:DefaultDataGenerator", 4 | "input_data_generator": "PyTorch:DefaultDataGenerator", 5 | "config": [ 6 | { 7 | "build": [], 8 | "input": [ 9 | { 10 | "args": [ 11 | { 12 | "dtype": "float", 13 | "shape": [ 14 | 128, 15 | 3, 16 | 224, 17 | 224 18 | ], 19 | "type": "tensor" 20 | } 21 | ] 22 | } 23 | ] 24 | } 25 | ] 26 | } 27 | } -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/configs/aten_ops.json: -------------------------------------------------------------------------------- 1 | { 2 | "aten::add": { 3 | "build_data_generator": "PyTorch:DefaultDataGenerator", 4 | "input_data_generator": "PyTorch:DefaultDataGenerator", 5 | "config": [ 6 | { 7 | "build": [ 8 | { 9 | "args": [ 10 | { 11 | "value": "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", 12 | "type": "str" 13 | } 14 | ] 15 | } 16 | ], 17 | "input": [ 18 | { 19 | "args": [ 20 | { 21 | "dtype": "float", 22 | "shape": [ 23 | 4, 24 | 4 25 | ], 26 | "type": "tensor" 27 | }, 28 | { 29 | "dtype": "float", 30 | "shape": [ 31 | 4, 32 | 4 33 | ], 34 | "type": "tensor" 35 | }, 36 | { 37 | "value": 1, 38 | "type": "int" 39 | } 40 | ] 41 | } 42 | ] 43 | } 44 | ] 45 | }, 46 | "aten::add_": { 47 | "build_data_generator": "PyTorch:DefaultDataGenerator", 48 | "input_data_generator": "PyTorch:DefaultDataGenerator", 49 | "config": [ 50 | { 51 | "build": [ 52 | { 53 | "args": [ 54 | { 55 | "value": "aten::add_.Tensor(Tensor(a!) 
self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", 56 | "type": "str" 57 | } 58 | ] 59 | } 60 | ], 61 | "input": [ 62 | { 63 | "args": [ 64 | { 65 | "dtype": "float", 66 | "shape": [ 67 | 4, 68 | 4 69 | ], 70 | "type": "tensor", 71 | "requires_grad": false 72 | }, 73 | { 74 | "dtype": "float", 75 | "shape": [ 76 | 4, 77 | 4 78 | ], 79 | "type": "tensor", 80 | "requires_grad": false 81 | }, 82 | { 83 | "value": 1, 84 | "type": "int" 85 | } 86 | ] 87 | } 88 | ] 89 | } 90 | ] 91 | }, 92 | "aten::matmul": { 93 | "build_data_generator": "PyTorch:DefaultDataGenerator", 94 | "input_data_generator": "PyTorch:DefaultDataGenerator", 95 | "config": [ 96 | { 97 | "build": [ 98 | { 99 | "args": [ 100 | { 101 | "value": "aten::matmul(Tensor self, Tensor other) -> Tensor", 102 | "type": "str" 103 | } 104 | ] 105 | } 106 | ], 107 | "input": [ 108 | { 109 | "args": [ 110 | { 111 | "dtype": "float", 112 | "shape": [ 113 | 4, 114 | 4 115 | ], 116 | "type": "tensor" 117 | }, 118 | { 119 | "dtype": "float", 120 | "shape": [ 121 | 4, 122 | 4 123 | ], 124 | "type": "tensor" 125 | } 126 | ] 127 | } 128 | ] 129 | } 130 | ] 131 | }, 132 | "aten::mul": { 133 | "build_data_generator": "PyTorch:DefaultDataGenerator", 134 | "input_data_generator": "PyTorch:DefaultDataGenerator", 135 | "config": [ 136 | { 137 | "build": [ 138 | { 139 | "args": [ 140 | { 141 | "value": "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", 142 | "type": "str" 143 | } 144 | ] 145 | } 146 | ], 147 | "input": [ 148 | { 149 | "args": [ 150 | { 151 | "dtype": "float", 152 | "shape": [ 153 | 4, 154 | 4 155 | ], 156 | "type": "tensor" 157 | }, 158 | { 159 | "dtype": "float", 160 | "shape": [ 161 | 4, 162 | 4 163 | ], 164 | "type": "tensor" 165 | } 166 | ] 167 | } 168 | ] 169 | } 170 | ] 171 | }, 172 | "aten::sum": { 173 | "build_data_generator": "PyTorch:DefaultDataGenerator", 174 | "input_data_generator": "PyTorch:DefaultDataGenerator", 175 | "config": [ 176 | { 177 | "build": [ 178 | { 179 | "args": [ 180 | { 181 | "value": "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)", 182 | "type": "str" 183 | } 184 | ] 185 | } 186 | ], 187 | "input": [ 188 | { 189 | "args": [ 190 | { 191 | "dtype": "float", 192 | "shape": [ 193 | 4, 194 | 4 195 | ], 196 | "type": "tensor" 197 | }, 198 | { 199 | "type": "none" 200 | } 201 | ] 202 | } 203 | ] 204 | } 205 | ] 206 | }, 207 | "aten::linear": { 208 | "build_data_generator": "PyTorch:DefaultDataGenerator", 209 | "input_data_generator": "PyTorch:DefaultDataGenerator", 210 | "config": [ 211 | { 212 | "build": [ 213 | { 214 | "args": [ 215 | { 216 | "value": "aten::linear(Tensor input, Tensor weight, Tensor? 
bias=None) -> Tensor", 217 | "type": "str" 218 | } 219 | ] 220 | } 221 | ], 222 | "input": [ 223 | { 224 | "args": [ 225 | { 226 | "dtype": "float", 227 | "shape": [ 228 | 4, 229 | 4 230 | ], 231 | "type": "tensor" 232 | }, 233 | { 234 | "dtype": "float", 235 | "shape": [ 236 | 4, 237 | 4 238 | ], 239 | "type": "tensor" 240 | }, 241 | { 242 | "type": "none" 243 | } 244 | ] 245 | } 246 | ] 247 | } 248 | ] 249 | } 250 | } -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/configs/batch_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "op_name": "torch.add", 3 | "build_id": "1:2", 4 | "op_info": { 5 | "input_data_generator": "PyTorch:DefaultDataGenerator", 6 | "config": [ 7 | { 8 | "input": [ 9 | { 10 | "id": "2_34", 11 | "args": [ 12 | { 13 | "dtype": "float", 14 | "shape": [ 15 | 256, 16 | 256 17 | ], 18 | "type": "tensor" 19 | }, 20 | { 21 | "dtype": "float", 22 | "shape": [ 23 | 256, 24 | 256 25 | ], 26 | "type": "tensor" 27 | } 28 | ] 29 | } 30 | ] 31 | } 32 | ] 33 | }, 34 | "run_options": { 35 | "device": "cuda", 36 | "pass_type": "forward", 37 | "warmup": 1, 38 | "iteration": 1, 39 | "out_stream": null, 40 | "resume_op_run_id": null 41 | } 42 | } -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/configs/llama2.json: -------------------------------------------------------------------------------- 1 | { 2 | "torch.mm": { 3 | "input_data_generator": "PyTorch:DefaultDataGenerator", 4 | "config": [ 5 | { 6 | "input": [ 7 | { 8 | "args": [ 9 | { 10 | "dtype": "bfloat16", 11 | "shape": [ 12 | 2048, 13 | 4096 14 | ], 15 | "type": "tensor" 16 | }, 17 | { 18 | "dtype": "bfloat16", 19 | "shape": [ 20 | 4096, 21 | 4096 22 | ], 23 | "type": "tensor" 24 | } 25 | ] 26 | } 27 | ] 28 | }, 29 | { 30 | "input": [ 31 | { 32 | "args": [ 33 | { 34 | "dtype": "bfloat16", 35 | "shape": [ 36 | 2048, 37 | 4096 38 | ], 39 | "type": "tensor" 40 | }, 41 | { 42 | "dtype": "bfloat16", 43 | "shape": [ 44 | 4096, 45 | 4096 46 | ], 47 | "type": "tensor" 48 | } 49 | ] 50 | } 51 | ] 52 | }, 53 | { 54 | "input": [ 55 | { 56 | "args": [ 57 | { 58 | "dtype": "bfloat16", 59 | "shape": [ 60 | 2048, 61 | 4096 62 | ], 63 | "type": "tensor" 64 | }, 65 | { 66 | "dtype": "bfloat16", 67 | "shape": [ 68 | 4096, 69 | 4096 70 | ], 71 | "type": "tensor" 72 | } 73 | ] 74 | } 75 | ] 76 | }, 77 | { 78 | "input": [ 79 | { 80 | "args": [ 81 | { 82 | "dtype": "bfloat16", 83 | "shape": [ 84 | 2048, 85 | 4096 86 | ], 87 | "type": "tensor" 88 | }, 89 | { 90 | "dtype": "bfloat16", 91 | "shape": [ 92 | 4096, 93 | 2048 94 | ], 95 | "type": "tensor" 96 | } 97 | ] 98 | } 99 | ] 100 | }, 101 | { 102 | "input": [ 103 | { 104 | "args": [ 105 | { 106 | "dtype": "bfloat16", 107 | "shape": [ 108 | 4096, 109 | 11008 110 | ], 111 | "type": "tensor" 112 | }, 113 | { 114 | "dtype": "bfloat16", 115 | "shape": [ 116 | 11008, 117 | 11008 118 | ], 119 | "type": "tensor" 120 | } 121 | ] 122 | } 123 | ] 124 | }, 125 | { 126 | "input": [ 127 | { 128 | "args": [ 129 | { 130 | "dtype": "bfloat16", 131 | "shape": [ 132 | 11008, 133 | 11008 134 | ], 135 | "type": "tensor" 136 | }, 137 | { 138 | "dtype": "bfloat16", 139 | "shape": [ 140 | 11008, 141 | 4096 142 | ], 143 | "type": "tensor" 144 | } 145 | ] 146 | } 147 | ] 148 | } 149 | ] 150 | } 151 | } 152 | -------------------------------------------------------------------------------- 
/train/compute/python/examples/pytorch/configs/mm.json: -------------------------------------------------------------------------------- 1 | { 2 | "torch.baddbmm": { 3 | "input_data_generator": "PyTorch:DefaultDataGenerator", 4 | "config": [ 5 | { 6 | "input": [ 7 | { 8 | "args": [ 9 | { 10 | "dtype": "float", 11 | "shape": [ 12 | 1, 13 | 1, 14 | 1 15 | ], 16 | "type": "tensor" 17 | }, 18 | { 19 | "dtype": "float", 20 | "shape": [ 21 | 1, 22 | 1, 23 | 1 24 | ], 25 | "type": "tensor" 26 | }, 27 | { 28 | "dtype": "float", 29 | "shape": [ 30 | 1, 31 | 1, 32 | 1 33 | ], 34 | "type": "tensor" 35 | } 36 | ], 37 | "kwargs": { 38 | "beta": { 39 | "type": "int", 40 | "value": 1 41 | }, 42 | "alpha": { 43 | "type": "int", 44 | "value": 1 45 | } 46 | } 47 | }, 48 | { 49 | "args": [ 50 | { 51 | "dtype": "float", 52 | "shape": [ 53 | 2, 54 | 1, 55 | 512 56 | ], 57 | "type": "tensor" 58 | }, 59 | { 60 | "dtype": "float", 61 | "shape": [ 62 | 2, 63 | 512, 64 | 512 65 | ], 66 | "type": "tensor" 67 | }, 68 | { 69 | "dtype": "float", 70 | "shape": [ 71 | 2, 72 | 512, 73 | 512 74 | ], 75 | "type": "tensor" 76 | } 77 | ], 78 | "kwargs": { 79 | "beta": { 80 | "type": "int", 81 | "value": 1 82 | }, 83 | "alpha": { 84 | "type": "int", 85 | "value": 1 86 | } 87 | } 88 | } 89 | ] 90 | } 91 | ] 92 | }, 93 | "torch.bmm": { 94 | "input_data_generator": "PyTorch:DefaultDataGenerator", 95 | "config": [ 96 | { 97 | "input": [ 98 | { 99 | "args": [ 100 | { 101 | "dtype": "float", 102 | "shape": [ 103 | 512, 104 | 1234, 105 | 30 106 | ], 107 | "type": "tensor" 108 | }, 109 | { 110 | "dtype": "float", 111 | "shape": [ 112 | 512, 113 | 30, 114 | 64 115 | ], 116 | "type": "tensor" 117 | } 118 | ] 119 | }, 120 | { 121 | "args": [ 122 | { 123 | "dtype": "float", 124 | "shape": [ 125 | 512, 126 | 64, 127 | 30 128 | ], 129 | "type": "tensor" 130 | }, 131 | { 132 | "dtype": "float", 133 | "shape": [ 134 | 512, 135 | 30, 136 | 1234 137 | ], 138 | "type": "tensor" 139 | } 140 | ] 141 | }, 142 | { 143 | "args": [ 144 | { 145 | "dtype": "float", 146 | "shape": [ 147 | 512, 148 | 1234, 149 | 64 150 | ], 151 | "type": "tensor" 152 | }, 153 | { 154 | "dtype": "float", 155 | "shape": [ 156 | 512, 157 | 64, 158 | 30 159 | ], 160 | "type": "tensor" 161 | } 162 | ] 163 | }, 164 | { 165 | "args": [ 166 | { 167 | "dtype": "float", 168 | "shape": [ 169 | 512, 170 | 64, 171 | 1234 172 | ], 173 | "type": "tensor" 174 | }, 175 | { 176 | "dtype": "float", 177 | "shape": [ 178 | 512, 179 | 1234, 180 | 30 181 | ], 182 | "type": "tensor" 183 | } 184 | ] 185 | }, 186 | { 187 | "args": [ 188 | { 189 | "dtype": "float", 190 | "shape": [ 191 | 2, 192 | 512, 193 | 512 194 | ], 195 | "type": "tensor" 196 | }, 197 | { 198 | "dtype": "float", 199 | "shape": [ 200 | 2, 201 | 512, 202 | 512 203 | ], 204 | "type": "tensor" 205 | } 206 | ] 207 | } 208 | ] 209 | } 210 | ] 211 | }, 212 | "torch.mm": { 213 | "input_data_generator": "PyTorch:DefaultDataGenerator", 214 | "config": [ 215 | { 216 | "input": [ 217 | { 218 | "args": [ 219 | { 220 | "dtype": "float", 221 | "shape": [ 222 | 1024, 223 | 1024 224 | ], 225 | "type": "tensor" 226 | }, 227 | { 228 | "dtype": "float", 229 | "shape": [ 230 | 1024, 231 | 1024 232 | ], 233 | "type": "tensor" 234 | } 235 | ] 236 | } 237 | ] 238 | } 239 | ] 240 | } 241 | } -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/configs/mm_range.json: -------------------------------------------------------------------------------- 1 | { 2 | "torch.baddbmm": { 3 | 
"input_iterator": "RangeConfigIterator", 4 | "input_data_generator": "PyTorch:DefaultDataGenerator", 5 | "config": [ 6 | { 7 | "input": [ 8 | { 9 | "args": [ 10 | { 11 | "dtype": "float", 12 | "shape": [ 13 | [ 14 | 1, 15 | 2, 16 | 1 17 | ], 18 | -1, 19 | -1 20 | ], 21 | "type": "tensor", 22 | "__range__": [ 23 | "shape" 24 | ], 25 | "__copy__": [ 26 | { 27 | "shape": [ 28 | 1, 29 | [ 30 | 1, 31 | 1 32 | ] 33 | ] 34 | }, 35 | { 36 | "shape": [ 37 | 2, 38 | [ 39 | 2, 40 | 2 41 | ] 42 | ] 43 | } 44 | ] 45 | }, 46 | { 47 | "dtype": "float", 48 | "shape": [ 49 | -1, 50 | [ 51 | 512, 52 | 513, 53 | 1 54 | ], 55 | [ 56 | 515, 57 | 518, 58 | 1 59 | ] 60 | ], 61 | "type": "tensor", 62 | "__range__": [ 63 | "shape" 64 | ], 65 | "__copy__": [ 66 | { 67 | "shape": [ 68 | 0, 69 | [ 70 | 0, 71 | 0 72 | ] 73 | ] 74 | } 75 | ] 76 | }, 77 | { 78 | "dtype": "float", 79 | "shape": [ 80 | -1, 81 | -1, 82 | [ 83 | 512, 84 | 515, 85 | 1 86 | ] 87 | ], 88 | "type": "tensor", 89 | "__range__": [ 90 | "shape" 91 | ], 92 | "__copy__": [ 93 | { 94 | "shape": [ 95 | 0, 96 | [ 97 | 0, 98 | 0 99 | ] 100 | ] 101 | }, 102 | { 103 | "shape": [ 104 | 1, 105 | [ 106 | 1, 107 | 2 108 | ] 109 | ] 110 | } 111 | ] 112 | } 113 | ], 114 | "kwargs": { 115 | "beta": { 116 | "type": "int", 117 | "value": 1 118 | }, 119 | "alpha": { 120 | "type": "int", 121 | "value": 1 122 | } 123 | } 124 | } 125 | ] 126 | } 127 | ] 128 | }, 129 | "torch.bmm": { 130 | "input_iterator": "RangeConfigIterator", 131 | "input_data_generator": "PyTorch:DefaultDataGenerator", 132 | "config": [ 133 | { 134 | "input": [ 135 | { 136 | "args": [ 137 | { 138 | "dtype": "float", 139 | "shape": [ 140 | 512, 141 | [ 142 | 512, 143 | 514, 144 | 1 145 | ], 146 | 30 147 | ], 148 | "type": "tensor", 149 | "__range__": [ 150 | "shape" 151 | ] 152 | }, 153 | { 154 | "dtype": "float", 155 | "shape": [ 156 | 512, 157 | 30, 158 | 64 159 | ], 160 | "type": "tensor" 161 | } 162 | ] 163 | }, 164 | { 165 | "args": [ 166 | { 167 | "dtype": "float", 168 | "shape": [ 169 | 512, 170 | 64, 171 | 30 172 | ], 173 | "type": "tensor" 174 | }, 175 | { 176 | "dtype": "float", 177 | "shape": [ 178 | 512, 179 | 30, 180 | [ 181 | 128, 182 | 130, 183 | 1 184 | ] 185 | ], 186 | "type": "tensor", 187 | "__range__": [ 188 | "shape" 189 | ] 190 | } 191 | ] 192 | } 193 | ] 194 | } 195 | ] 196 | } 197 | } -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/configs/resnet.json: -------------------------------------------------------------------------------- 1 | { 2 | "pytorch.model.resnet": { 3 | "build_data_generator": "PyTorch:DefaultDataGenerator", 4 | "input_data_generator": "PyTorch:DefaultDataGenerator", 5 | "config": [ 6 | { 7 | "build": [], 8 | "input": [ 9 | { 10 | "args": [ 11 | { 12 | "dtype": "float", 13 | "shape": [ 14 | 128, 15 | 3, 16 | 224, 17 | 224 18 | ], 19 | "type": "tensor" 20 | } 21 | ] 22 | } 23 | ] 24 | } 25 | ] 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/configs/simple_add.json: -------------------------------------------------------------------------------- 1 | { 2 | "torch.add": { 3 | "input_data_generator": "PyTorch:DefaultDataGenerator", 4 | "config": [ 5 | { 6 | "input": [ 7 | { 8 | "args": [ 9 | { 10 | "dtype": "float", 11 | "shape": [ 12 | 256, 13 | 256 14 | ], 15 | "type": "tensor" 16 | }, 17 | { 18 | "dtype": "float", 19 | "shape": [ 20 | 256, 21 | 256 22 | ], 23 | "type": "tensor" 24 | } 25 | ] 26 | } 
27 | ] 28 | } 29 | ] 30 | } 31 | } -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/configs/simple_add_range.json: -------------------------------------------------------------------------------- 1 | { 2 | "torch.add": { 3 | "input_iterator": "RangeConfigIterator", 4 | "input_data_generator": "PyTorch:DefaultDataGenerator", 5 | "config": [ 6 | { 7 | "input": [ 8 | { 9 | "args": [ 10 | { 11 | "dtype": "float", 12 | "shape": [ 13 | 128, 14 | 256 15 | ], 16 | "type": "tensor" 17 | }, 18 | { 19 | "value": [1, 10, 1], 20 | "__range__": [ 21 | "value" 22 | ], 23 | "type": "int" 24 | } 25 | ] 26 | }, 27 | { 28 | "args": [ 29 | { 30 | "dtype": "float", 31 | "shape": [ 32 | 256, 33 | 256 34 | ], 35 | "type": "tensor" 36 | }, 37 | { 38 | "value": [10, 20, 2], 39 | "__range__": [ 40 | "value" 41 | ], 42 | "type": "int" 43 | } 44 | ] 45 | } 46 | ] 47 | }, 48 | { 49 | "input": [ 50 | { 51 | "args": [ 52 | { 53 | "dtype": "float", 54 | "shape": [ 55 | 64, 56 | 256 57 | ], 58 | "type": "tensor" 59 | }, 60 | { 61 | "value": [30, 40, 1], 62 | "__range__": [ 63 | "value" 64 | ], 65 | "type": "int" 66 | } 67 | ] 68 | }, 69 | { 70 | "args": [ 71 | { 72 | "dtype": "float", 73 | "shape": [ 74 | 64, 75 | 256 76 | ], 77 | "type": "tensor" 78 | }, 79 | { 80 | "value": [40, 50, 2], 81 | "__range__": [ 82 | "value" 83 | ], 84 | "type": "int" 85 | } 86 | ] 87 | } 88 | ] 89 | } 90 | ] 91 | } 92 | } -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/configs/simple_mm.json: -------------------------------------------------------------------------------- 1 | { 2 | "torch.mm": { 3 | "input_data_generator": "PyTorch:DefaultDataGenerator", 4 | "config": [ 5 | { 6 | "input": [ 7 | { 8 | "args": [ 9 | { 10 | "dtype": "float", 11 | "shape": [ 12 | 1024, 13 | 1024 14 | ], 15 | "type": "tensor" 16 | }, 17 | { 18 | "dtype": "float", 19 | "shape": [ 20 | 1024, 21 | 1024 22 | ], 23 | "type": "tensor" 24 | } 25 | ] 26 | } 27 | ] 28 | } 29 | ] 30 | } 31 | } -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/configs/simple_mm_range.json: -------------------------------------------------------------------------------- 1 | { 2 | "torch.mm": { 3 | "input_iterator": "RangeConfigIterator", 4 | "input_data_generator": "PyTorch:DefaultDataGenerator", 5 | "config": [ 6 | { 7 | "input": [ 8 | { 9 | "args": [ 10 | { 11 | "dtype": "float", 12 | "shape": [ 13 | [128, 2048, 16], 14 | 256 15 | ], 16 | "__range__": [ 17 | "shape" 18 | ], 19 | "type": "tensor" 20 | }, 21 | { 22 | "dtype": "float", 23 | "shape": [ 24 | 256, 25 | 256 26 | ], 27 | "type": "tensor" 28 | } 29 | ] 30 | } 31 | ] 32 | } 33 | ] 34 | } 35 | } -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/configs/split_table_batched_embeddings_ops.json: -------------------------------------------------------------------------------- 1 | { 2 | "SplitTableBatchedEmbeddingBagsCodegen": { 3 | "build_iterator": "RangeConfigIterator", 4 | "input_iterator": "SplitTableBatchedEmbeddingBagsCodegenInputIterator", 5 | "build_data_generator": "PyTorch:DefaultDataGenerator", 6 | "input_data_generator": "SplitTableBatchedEmbeddingBagsCodegenInputDataGenerator", 7 | "config": [ 8 | { 9 | "build": [ 10 | { 11 | "args": [ 12 | { 13 | "type": "int", 14 | "name": "num_tables", 15 | "value": 1 16 | }, 17 | { 18 | "name": "rows", 19 | 
"type": "int", 20 | "value": 228582 21 | }, 22 | { 23 | "name": "dim", 24 | "type": "int", 25 | "value": 128 26 | }, 27 | { 28 | "type": "int", 29 | "name": "pooling", 30 | "value": 0 31 | }, 32 | { 33 | "type": "bool", 34 | "name": "weighted", 35 | "value": false 36 | }, 37 | { 38 | "type": "str", 39 | "name": "weights_precision", 40 | "value": "fp16" 41 | } 42 | ], 43 | "kwargs": { 44 | "optimizer": { 45 | "type": "str", 46 | "value": "exact_row_wise_adagrad" 47 | } 48 | } 49 | } 50 | ], 51 | "input": [ 52 | { 53 | "args": [ 54 | { 55 | "type": "int", 56 | "name": "batch_size", 57 | "value": 512 58 | }, 59 | { 60 | "name": "pooling_factor", 61 | "type": "int", 62 | "value": 50 63 | } 64 | ] 65 | } 66 | ] 67 | } 68 | ] 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/run_op.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ...lib import pytorch as lib_pytorch 4 | from ...lib.config import make_op_config 5 | from ...lib.init_helper import load_modules 6 | from ...lib.pytorch.config_util import ( 7 | create_bench_config, 8 | create_op_args, 9 | create_type, 10 | ExecutionPass, 11 | get_benchmark_options, 12 | ) 13 | from ...lib.pytorch.op_executor import OpExecutor 14 | from ...workloads import pytorch as workloads_pytorch 15 | 16 | 17 | def main(): 18 | # Load PyTorch implementations for data generator and operators. 19 | load_modules(lib_pytorch) 20 | 21 | # Load PyTorch operator workloads. 22 | load_modules(workloads_pytorch) 23 | 24 | # Important to set num of threads to 1, some ops are not thread safe and 25 | # also improves measurement stability. 26 | torch.set_num_threads(1) 27 | 28 | op_name = "torch.mm" 29 | bench_config = create_bench_config(op_name) 30 | tensor_1 = create_type("tensor") 31 | tensor_1["shape"] = [128, 128] 32 | tensor_2 = create_type("tensor") 33 | tensor_2["shape"] = [128, 128] 34 | 35 | op_info = bench_config[op_name] 36 | 37 | # Add the two tensors as first and second positional args for the operator. 38 | input_config = create_op_args([tensor_1, tensor_2], {}) 39 | op_info["config"][0]["input"].append(input_config) 40 | print(op_info) 41 | 42 | # Get the default benchmark options 43 | run_options = get_benchmark_options() 44 | 45 | # By default, benchmark will run the forward pass. 46 | # By setting backward (which requires running forward pass), the benchmark 47 | # will run both forward and backward pass. 48 | run_options["pass_type"] = ExecutionPass.BACKWARD 49 | 50 | # Create OperatorConfig that initialize the actual operator workload and 51 | # various generators to create inputs for the operator. 52 | op_config = make_op_config(op_name, op_info, run_options["device"]) 53 | 54 | # Generate the actual data for inputs. For operators that require a build 55 | # step, a similar data generation is needed for build config. 56 | input_config = op_info["config"][0]["input"][0] 57 | input_data_gen = op_config.input_data_generator() 58 | (input_args, input_kwargs) = input_data_gen.get_data( 59 | input_config, run_options["device"] 60 | ) 61 | 62 | # Create an OpExecutor to run the actual workload. 63 | op_exe = OpExecutor(op_name, op_config.op, run_options) 64 | 65 | # Run and collect the result metrics. 66 | # "0:0:0" is the run_id, an unique string identifies this benchmark run. 67 | # It may be used to corrleate with other benchmark data (like metrics). 
68 | # The run_id format is: "config_id:build_id:input_id" 69 | result = op_exe.run(input_args, input_kwargs, "0:0:0") 70 | 71 | # Loop through and print the metrics. 72 | print("### Benchmark Results ###") 73 | for pass_name, pass_data in result.items(): 74 | print(f"pass: {pass_name}") 75 | for metric_name, metrics in pass_data.items(): 76 | print(metric_name, metrics) 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /train/compute/python/examples/pytorch/run_op_split_table_batched_embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fbgemm_gpu.split_table_batched_embeddings_ops import PoolingMode 3 | 4 | from ...lib import pytorch as lib_pytorch 5 | from ...lib.config import make_op_config 6 | from ...lib.init_helper import load_modules 7 | from ...lib.pytorch.config_util import ( 8 | create_op_args, 9 | create_op_info, 10 | ExecutionPass, 11 | get_benchmark_options, 12 | ) 13 | from ...lib.pytorch.op_executor import OpExecutor 14 | from ...workloads import pytorch as workloads_pytorch 15 | 16 | 17 | def main(): 18 | # Load PyTorch implementations for data generator and operators. 19 | load_modules(lib_pytorch) 20 | 21 | # Load PyTorch operator workloads. 22 | load_modules(workloads_pytorch) 23 | 24 | # Important to set num of threads to 1, some ops are not thread safe and 25 | # also improves measurement stability. 26 | torch.set_num_threads(1) 27 | 28 | op_name = "SplitTableBatchedEmbeddingBagsCodegen" 29 | op_info = create_op_info() 30 | op_info["input_data_generator"] = ( 31 | "SplitTableBatchedEmbeddingBagsCodegenInputDataGenerator" 32 | ) 33 | print(op_info) 34 | 35 | # Get the default benchmark options. 36 | run_options = get_benchmark_options() 37 | 38 | # Create OperatorConfig that initializes the actual operator workload and 39 | # various generators to create inputs for the operator. 40 | op_config = make_op_config(op_name, op_info, run_options["device"]) 41 | 42 | # By default, benchmark will run the forward pass. 43 | # By setting backward (which requires running forward pass), the benchmark 44 | # will run both forward and backward pass. 45 | run_options["pass_type"] = ExecutionPass.BACKWARD 46 | 47 | # Define config parameters required for input data generation and operator building. 48 | num_tables = 1 49 | rows = 228582 50 | dim = 128 51 | batch_size = 512 52 | pooling_factor = 50 53 | weighted = True 54 | weights_precision = "fp16" 55 | optimizer = "exact_row_wise_adagrad" 56 | 57 | # Construct configuration for input data generator. 58 | data_generator_config = create_op_args( 59 | [ 60 | {"type": "int", "name": "num_tables", "value": num_tables}, 61 | {"type": "int", "name": "rows", "value": rows}, 62 | {"type": "int", "name": "dim", "value": dim}, 63 | {"type": "int", "name": "batch_size", "value": batch_size}, 64 | {"type": "int", "name": "pooling_factor", "value": pooling_factor}, 65 | {"type": "bool", "name": "weighted", "value": weighted}, 66 | {"type": "str", "name": "weights_precision", "value": weights_precision}, 67 | ], 68 | {"optimizer": {"type": "str", "value": optimizer}}, 69 | ) 70 | 71 | # Generate the actual data for inputs. 72 | input_data_gen = op_config.input_data_generator() 73 | (input_args, input_kwargs) = input_data_gen.get_data( 74 | data_generator_config, run_options["device"] 75 | ) 76 | 77 | # Construct and initialize the SplitTableBatchedEmbeddingBagsCodegen operator. 
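    # The positional arguments below mirror the "build" section of the
    # split_table_batched_embeddings_ops.json example above; PoolingMode.SUM
    # stands in for that config's integer "pooling" field (0 appears to select
    # the same sum-pooling mode in fbgemm_gpu).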
78 | op_config.op.build( 79 | num_tables, 80 | rows, 81 | dim, 82 | PoolingMode.SUM, 83 | weighted, 84 | weights_precision, 85 | optimizer, 86 | ) 87 | 88 | # Create an OpExecutor to run the actual workload. 89 | op_exe = OpExecutor(op_name, op_config.op, run_options) 90 | 91 | # Run and collect the result metrics. 92 | result = op_exe.run(input_args, input_kwargs, "0:0:0") 93 | 94 | # Loop through and print the metrics. 95 | print("### Benchmark Results ###") 96 | for pass_name, pass_data in result.items(): 97 | print(f"pass: {pass_name}") 98 | for metric_name, metrics in pass_data.items(): 99 | print(metric_name) 100 | print(metrics) 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /train/compute/python/lib/__init__.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | __base_version__ = "1.0.0" 4 | 5 | 6 | def __generate_git_param_train_compute_version(): 7 | # git hash 8 | commit_version = "+git" 9 | try: 10 | import git 11 | 12 | repo = git.Repo(search_parent_directories=True) 13 | commit_version = f"{commit_version}.{repo.head.object.hexsha}" 14 | except Exception: 15 | pass 16 | 17 | timestamp = int(time.time()) 18 | commit_version = f"{commit_version}.{timestamp}" 19 | return f"{__base_version__}{commit_version}" 20 | 21 | 22 | def __generate_fbcode_param_train_compute_version(): 23 | # Meta build hash 24 | commit_version = "+fbcode" 25 | try: 26 | from __manifest__ import fbmake 27 | 28 | if fbmake["revision"]: 29 | commit_version = f"{commit_version}.{fbmake['revision']}" 30 | if fbmake["time"]: 31 | commit_version = f"{commit_version}.{fbmake['epochtime']}" 32 | else: 33 | timestamp = int(time.time()) 34 | commit_version = f"{commit_version}.{timestamp}" 35 | except Exception: 36 | commit_version = "+local" 37 | 38 | return f"{__base_version__}{commit_version}" 39 | 40 | 41 | def __get_version(): 42 | # First try to get the version from setup.py generated _version.py file. 43 | try: 44 | from ._version import __param_train_compute_version 45 | 46 | return __param_train_compute_version 47 | except Exception: 48 | pass 49 | # If failed try to get fbcode build version. 
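    # The resulting string looks like "1.0.0+fbcode.<revision>.<epoch>" inside
    # an fbcode build, or "1.0.0+local" when no build metadata is available.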
50 | return __generate_fbcode_param_train_compute_version() 51 | 52 | 53 | __version__ = __get_version() 54 | -------------------------------------------------------------------------------- /train/compute/python/lib/config.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | from typing import Any, Dict, List, Optional, Type 4 | 5 | from .data import data_generator_map, DataGenerator 6 | from .init_helper import get_logger 7 | from .iterator import config_iterator_map, ConfigIterator, DefaultConfigIterator 8 | from .operator import op_map, OperatorInterface 9 | 10 | 11 | logger = get_logger() 12 | 13 | 14 | class OperatorConfig: 15 | def __init__( 16 | self, name: str, info: dict[str, Any], op: OperatorInterface | None = None 17 | ): 18 | self._name: str = name 19 | self._info: dict[str, Any] = info 20 | self._op: OperatorInterface | None = op 21 | 22 | @property 23 | def name(self) -> str: 24 | return self._name 25 | 26 | @property 27 | def op(self) -> OperatorInterface | None: 28 | return self._op 29 | 30 | @op.setter 31 | def op(self, value: OperatorInterface): 32 | self._op = value 33 | 34 | @property 35 | def info(self) -> dict[str, Any]: 36 | return self._info 37 | 38 | @property 39 | def build_iterator(self) -> type[ConfigIterator]: 40 | return self._build_iterator 41 | 42 | @build_iterator.setter 43 | def build_iterator(self, value: type[ConfigIterator]): 44 | self._build_iterator = value 45 | 46 | @property 47 | def input_iterator(self) -> type[ConfigIterator]: 48 | return self._input_iterator 49 | 50 | @input_iterator.setter 51 | def input_iterator(self, value: type[ConfigIterator]): 52 | self._input_iterator = value 53 | 54 | @property 55 | def build_data_generator(self) -> type[DataGenerator]: 56 | return self._build_data_generator 57 | 58 | @build_data_generator.setter 59 | def build_data_generator(self, value: type[DataGenerator]): 60 | self._build_data_generator = value 61 | 62 | @property 63 | def input_data_generator(self) -> type[DataGenerator]: 64 | return self._input_data_generator 65 | 66 | @input_data_generator.setter 67 | def input_data_generator(self, value: type[DataGenerator]): 68 | self._input_data_generator = value 69 | 70 | 71 | def make_op_config(op_name: str, op_info: dict[str, Any], device: str): 72 | global op_map 73 | if op_name in op_map: 74 | op = op_map[op_name] 75 | op.device = device 76 | else: 77 | op = None 78 | op_config = OperatorConfig(op_name, op_info, op) 79 | 80 | def get(key, table, default): 81 | nonlocal op_info 82 | if key in op_info: 83 | result = op_info[key] 84 | if result and result in table: 85 | return table[result] 86 | return default 87 | 88 | op_config.build_iterator = get( 89 | "build_iterator", config_iterator_map, DefaultConfigIterator 90 | ) 91 | op_config.input_iterator = get( 92 | "input_iterator", config_iterator_map, DefaultConfigIterator 93 | ) 94 | op_config.build_data_generator = get( 95 | "build_data_generator", data_generator_map, None 96 | ) 97 | op_config.input_data_generator = get( 98 | "input_data_generator", data_generator_map, None 99 | ) 100 | 101 | # input_data_generator is required 102 | if not op_config.input_data_generator: 103 | logger.warning( 104 | f"{op_name} has invalid input_data_generator: {op_config.input_data_generator}" 105 | ) 106 | return None 107 | 108 | return op_config 109 | 110 | 111 | class BenchmarkConfig: 112 | """ 113 | BenchmarkConfig stores loaded configuration data. 
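    Typical usage is to construct it with the run options returned by
    get_benchmark_options(), then call load_json_file(), load_json(), or load()
    to populate op_configs.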
114 | """ 115 | 116 | def __init__(self, run_options: dict[str, Any]): 117 | self.run_options = run_options 118 | self._op_configs = [] 119 | self.bench_config = None 120 | 121 | def _process_bench_config(self): 122 | for op_name, op_info in self.bench_config.items(): 123 | op_config = make_op_config(op_name, op_info, self.run_options["device"]) 124 | if op_config is not None: 125 | self._op_configs.append(op_config) 126 | 127 | def load_json_file(self, config_file_name: str): 128 | with open(config_file_name) as config_file: 129 | self.bench_config = json.load(config_file) 130 | self._process_bench_config() 131 | 132 | def load_json(self, config_json: str): 133 | self.bench_config = json.loads(config_json) 134 | self._process_bench_config() 135 | 136 | def load(self, config: dict[str, Any]): 137 | self.bench_config = copy.deepcopy(config) 138 | self._process_bench_config() 139 | 140 | @property 141 | def op_configs(self) -> list[OperatorConfig]: 142 | return self._op_configs 143 | 144 | def has_op(self, op: str): 145 | return (op in self.op_configs) and (op in op_map) 146 | -------------------------------------------------------------------------------- /train/compute/python/lib/data.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Dict, Type 3 | 4 | from .init_helper import get_logger 5 | 6 | logger = get_logger() 7 | 8 | 9 | class DataGenerator(metaclass=abc.ABCMeta): 10 | @classmethod 11 | def __subclasshook__(cls, subclass): 12 | return ( 13 | hasattr(subclass, "get_data") 14 | and callable(subclass.get_data) 15 | or NotImplemented 16 | ) 17 | 18 | def __init__(self): 19 | pass 20 | 21 | # Loads arg configurations and generates the arg data for an op. 22 | @abc.abstractmethod 23 | def get_data(self): 24 | raise NotImplementedError 25 | 26 | 27 | def register_data_generator(name: str, data_gen_class: type[DataGenerator]): 28 | global data_generator_map 29 | logger.debug(f"register data generator: {name}") 30 | if name not in data_generator_map: 31 | data_generator_map[name] = data_gen_class 32 | else: 33 | raise ValueError(f"Duplicate data generator registration name: {name}") 34 | 35 | 36 | data_generator_map: dict[str, type[DataGenerator]] = {} 37 | -------------------------------------------------------------------------------- /train/compute/python/lib/generator.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Set 2 | 3 | 4 | def full_range(a: int, b: int, s: int = 1): 5 | """ 6 | Returns inclusive range: a <= x <= b, by step of s 7 | """ 8 | return range(a, b + 1, s) 9 | 10 | 11 | # Repeatable iterator for lists. 12 | class IterableList: 13 | def __init__(self, items: list[Any]): 14 | self.items = items 15 | 16 | def __iter__(self): 17 | return self.Iterator(self.items) 18 | 19 | class Iterator: 20 | def __init__(self, items: list[Any]): 21 | self.iter = iter(items) 22 | 23 | def __iter__(self): 24 | return self 25 | 26 | def __next__(self): 27 | return next(self.iter) 28 | 29 | 30 | class ListProduct: 31 | """ 32 | ListProduct takes a list of repeatable iterables (like range()), and 33 | generates the Cartesian product of the iterables. 34 | 35 | Important: 36 | The list returned will be mutated in place for each iteration. If the 37 | user code wants to keep a copy or modify the generated list, it should 38 | make a copy of the returned result using `copy.deepcopy(generated_list)`. 
39 | Example: 40 | ``` 41 | result = [] 42 | for gen_list in ListProduct(iter_list): 43 | result.append(copy.deepcopy(gen_list)) 44 | ``` 45 | 46 | This interface wraps the Iterator so that a new iterator is created once 47 | the generator is exhausted. This allows repeatable iterations, i.e. 48 | 49 | iter_list_1 = [range(2, 6, 2), range(1, 3, 1), range(2, 4, 1)] 50 | iter_list_2 = [range(2, 6, 2), range(1, 3, 1)] 51 | prod = ListProduct([ListProduct(iter_list_1), ListProduct(iter_list_2)]) 52 | for i in prod: 53 | print("i",i) 54 | 55 | for i in prod: 56 | print("i",i) 57 | """ 58 | 59 | def __init__(self, iter_list: list[Any]): 60 | self.iter_list: list[Any] = iter_list 61 | 62 | def __iter__(self): 63 | return self.Iterator(self.iter_list, [None] * len(self.iter_list), 0) 64 | 65 | class Iterator: 66 | def __init__(self, iter_list: list[Any], val_list: list[Any], idx: int): 67 | self.generator = self._generate_next(iter_list, val_list, idx) 68 | 69 | def __iter__(self): 70 | return self 71 | 72 | def _generate_next(self, iter_list: list[Any], val_list: list[Any], idx: int): 73 | if iter_list: 74 | # If current item is iterable, loop through and recursive to next 75 | # item in the list 76 | if type(iter_list[0]) in iterable_types: 77 | for i in iter_list[0]: 78 | val_list[idx] = i 79 | if len(iter_list) == 1: 80 | yield val_list 81 | else: 82 | yield from self._generate_next( 83 | iter_list[1:], val_list, idx + 1 84 | ) 85 | # If current item is not iterable, just assign and recursive to next 86 | # item in the list 87 | else: 88 | val_list[idx] = iter_list[0] 89 | if len(iter_list) == 1: 90 | yield val_list 91 | else: 92 | yield from self._generate_next(iter_list[1:], val_list, idx + 1) 93 | else: 94 | yield iter_list 95 | 96 | def __next__(self): 97 | return next(self.generator) 98 | 99 | 100 | class TableProduct: 101 | def __init__(self, table: dict[Any, Any]): 102 | self.table: dict[Any, Any] = table 103 | self.result: dict[Any, Any] = {} 104 | 105 | def __iter__(self): 106 | iterable_keys = [] 107 | self.result = dict.fromkeys(self.table) 108 | # check which key/val has iterables, copy non iterable values to 109 | # result table 110 | for key, val in self.table.items(): 111 | # Only works with new classes with __iter__ interface. 112 | # If needed use iter() to check, but more expensive. 
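            # Illustrative example: for {"alpha": range(1, 3), "beta": 7},
            # "alpha" is treated as an iterable key and expanded by the
            # Iterator below, while "beta" is copied through unchanged into
            # every generated table.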
113 | if type(val) in iterable_types: 114 | iterable_keys.append(key) 115 | else: 116 | self.result[key] = val 117 | return self.Iterator(self.table, iterable_keys, self.result, 0) 118 | 119 | class Iterator: 120 | def __init__( 121 | self, 122 | table: dict[Any, Any], 123 | iterable_keys: list[Any], 124 | result: dict[Any, Any], 125 | idx: int, 126 | ): 127 | self.generator = self._generate_next(table, iterable_keys, result, idx) 128 | 129 | def __iter__(self): 130 | return self 131 | 132 | def _generate_next( 133 | self, 134 | table: dict[Any, Any], 135 | iterable_keys: list[Any], 136 | result: dict[Any, Any], 137 | idx: int, 138 | ): 139 | if table: 140 | if not iterable_keys: 141 | yield table 142 | else: 143 | for val in table[iterable_keys[0]]: 144 | result[iterable_keys[0]] = val 145 | if len(iterable_keys) == 1: 146 | yield result 147 | else: 148 | yield from self._generate_next( 149 | table, iterable_keys[1:], result, idx + 1 150 | ) 151 | else: 152 | yield table 153 | 154 | def __next__(self): 155 | return next(self.generator) 156 | 157 | 158 | iterable_types: set[Any] = {range, IterableList, ListProduct, TableProduct} 159 | -------------------------------------------------------------------------------- /train/compute/python/lib/init_helper.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import pkgutil 4 | 5 | 6 | _logger = None 7 | _logger_stream_handler = None 8 | 9 | 10 | def get_logger(): 11 | global _logger 12 | if _logger: 13 | return _logger 14 | else: 15 | return init_logging(logging.INFO) 16 | 17 | 18 | def init_logging(log_level): 19 | global _logger 20 | global _logger_stream_handler 21 | if log_level is logging.DEBUG: 22 | FORMAT = "[%(asctime)s] %(process)d %(filename)s:%(lineno)-3d [%(levelname)s]: %(message)s" 23 | else: 24 | FORMAT = "[%(asctime)s] %(process)d [%(levelname)s]: %(message)s" 25 | _logger = logging.getLogger("param_bench") 26 | _logger.setLevel(log_level) 27 | # Reset the stream handlers to avoid multiple outputs. 28 | _logger.handlers.clear() 29 | # Do not use parent logger to avoid duplicate messages. 30 | _logger.propagate = False 31 | _logger_stream_handler = logging.StreamHandler() 32 | _logger_stream_handler.setLevel(log_level) 33 | formatter = logging.Formatter(FORMAT) 34 | _logger_stream_handler.setFormatter(formatter) 35 | _logger.addHandler(_logger_stream_handler) 36 | return _logger 37 | 38 | 39 | logger = get_logger() 40 | 41 | 42 | def load_modules(package): 43 | """ 44 | Given a package, load/import all the modules in that package. 45 | See https://packaging.python.org/guides/creating-and-discovering-plugins/ 46 | """ 47 | modules = pkgutil.iter_modules(package.__path__, package.__name__ + ".") 48 | for _, name, _ in modules: 49 | logger.debug(f"loading module: {name}") 50 | try: 51 | importlib.import_module(name) 52 | except ModuleNotFoundError as error: 53 | logger.warning( 54 | f"failed to import module: {name}. ModuleNotFoundError: {error}" 55 | ) 56 | 57 | 58 | def load_package(package) -> bool: 59 | """ 60 | Try to load third-party modules, return false if failed. 61 | """ 62 | logger.debug(f"loading package: {package}") 63 | try: 64 | importlib.import_module(package) 65 | except ModuleNotFoundError as error: 66 | logger.warning( 67 | f"failed to import package: {package}. 
ModuleNotFoundError: {error}" 68 | ) 69 | return False 70 | return True 71 | -------------------------------------------------------------------------------- /train/compute/python/lib/operator.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Dict, Type 3 | 4 | from .init_helper import get_logger 5 | 6 | logger = get_logger() 7 | 8 | 9 | class OperatorInterface(metaclass=abc.ABCMeta): 10 | """ 11 | The OperatorInterface assumes the following operations: 12 | 13 | - An operator may require a build/initialization step. 14 | - Forward operation is always required. 15 | - Backward may require a gradient input, and create_grad should not be part 16 | of the benchmark measurement. 17 | """ 18 | 19 | @classmethod 20 | def __subclasshook__(cls, subclass): 21 | return ( 22 | hasattr(subclass, "forward") 23 | and callable(subclass.forward) 24 | or NotImplemented 25 | ) 26 | 27 | def __init__(self): 28 | self.device = None 29 | 30 | # Construct and initialize the operator. 31 | def build(self, *args, **kwargs): 32 | pass 33 | 34 | # Reset any state and remove allocated resources. 35 | def cleanup(self): 36 | pass 37 | 38 | @abc.abstractmethod 39 | def forward(self, *args, **kwargs): 40 | raise NotImplementedError 41 | 42 | def create_grad(self): 43 | raise NotImplementedError 44 | 45 | def backward(self): 46 | raise NotImplementedError 47 | 48 | 49 | def register_operator(name: str, operator_class: type[OperatorInterface]): 50 | global op_map 51 | logger.debug(f"register op: {name}") 52 | if name not in op_map: 53 | op_map[name] = operator_class 54 | else: 55 | raise ValueError(f"Duplicate operator registration name: {name}") 56 | 57 | 58 | def register_operators(op_dict: dict[str, type[OperatorInterface]]): 59 | global op_map 60 | for name, operator_class in op_dict.items(): 61 | logger.debug(f"register op: {name}") 62 | if name not in op_map: 63 | op_map[name] = operator_class 64 | else: 65 | raise ValueError(f"Duplicate operator registration name: {name}") 66 | 67 | 68 | # Global operator registry, a mapping of name to operator object 69 | op_map: dict[str, type[OperatorInterface]] = {} 70 | -------------------------------------------------------------------------------- /train/compute/python/lib/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/train/compute/python/lib/pytorch/__init__.py -------------------------------------------------------------------------------- /train/compute/python/lib/pytorch/benchmark.py: -------------------------------------------------------------------------------- 1 | from ..init_helper import get_logger 2 | 3 | logger = get_logger() 4 | 5 | from typing import List, Type 6 | 7 | from ..config import BenchmarkConfig, OperatorConfig 8 | from .build_executor import BuildExecutor, OpBuildExecutor, StopBenchmarkException 9 | from .config_util import init_pytorch 10 | from .operator_impl import TorchScriptOp 11 | 12 | UNSUPPORTED_OPS = [ 13 | "aten::record_stream", 14 | "aten::to", 15 | "aten::select", 16 | "aten::item", 17 | "aten::cat", 18 | "aten::split_with_sizes", 19 | ] 20 | 21 | 22 | class Benchmark: 23 | """ 24 | Benchmark is the high level interface to collect metrics for a large number 25 | of workloads with many configurations. This class does not execute the 26 | benchmark directly. It takes a BuildExecutor class to create an executor. 
27 | The build executor a build config and operator inputs. The actual execution 28 | of the workloads are implemented in OpExecutor. 29 | 30 | The reason to have this flexibility is to provide a library interface that 31 | allows use cases where a small number of build and input configurations can 32 | be created on the fly and the user simply wants to get the metrics directly 33 | from OpExecutor. 34 | 35 | In other cases, a derived class of BuildExecutor may implement a parallel 36 | (multiprocess) way to run the benchmarks, or run additional tool chains. 37 | All this can be done by implementing a new BuildExecutor and without 38 | modifying the benchmark logics in the OpExecutor. 39 | 40 | bench_config: contains all the benchmark configurations for the workloads. 41 | 42 | build_executor: a BuildExecutor that takes a concrete build config and operator 43 | inputs, op configuration, to run and collect benchmark metrics. 44 | """ 45 | 46 | def __init__( 47 | self, bench_config: BenchmarkConfig, build_executor: type[BuildExecutor] 48 | ): 49 | init_pytorch(bench_config.run_options) 50 | self.bench_config = bench_config 51 | self.build_executor = build_executor 52 | self.run_options = bench_config.run_options 53 | 54 | # Construct a BuildExecutor 55 | self.build_executor = build_executor(self.run_options) 56 | self.build_executor.set_resume_op_run_id(self.run_options["resume_op_run_id"]) 57 | self.build_executor.set_stop_op_run_id(self.run_options["stop_op_run_id"]) 58 | 59 | def run(self): 60 | try: 61 | for op_config in self.bench_config.op_configs: 62 | if op_config.op is None: 63 | if ( 64 | op_config.name not in UNSUPPORTED_OPS 65 | and op_config.name.startswith("aten::") 66 | ): 67 | logger.info(f"register torchscript op: {op_config.name}") 68 | op_config.op = TorchScriptOp(op_config.name) 69 | op_config.op.device = self.run_options["device"] 70 | 71 | if op_config.op is not None: 72 | self.run_op(op_config) 73 | except StopBenchmarkException as stop_event: 74 | logger.info(stop_event) 75 | 76 | def run_op(self, op_config: OperatorConfig) -> list[str]: 77 | logger.info(f"### op: {op_config.name}") 78 | config_id = 0 79 | for config in op_config.info["config"]: 80 | op_run_id = str(config_id) 81 | logger.info(f"config_id: [{op_run_id}]") 82 | if "input" not in config: 83 | logger.error( 84 | f"{op_config.name} has no input configurations defined, skipped." 
85 | ) 86 | return 87 | 88 | generate_build_config = None 89 | if op_config.build_iterator and "build" in config: 90 | logger.debug(f"build_config: {config['build']}") 91 | if config["build"]: 92 | generate_build_config = op_config.build_iterator( 93 | config, "build", self.run_options["device"] 94 | ) 95 | 96 | build_input_config = {} 97 | if generate_build_config: 98 | logger.debug("generating build config") 99 | for build_id, build_config in generate_build_config: 100 | logger.info(f"build_id: [{build_id}]") 101 | logger.debug(f"build_config: {build_config}") 102 | op_run_id = f"{op_run_id}|{build_id}" 103 | build_input_config["build"] = build_config 104 | build_input_config["input"] = config["input"] 105 | self.build_executor.run(op_config, build_input_config, op_run_id) 106 | else: 107 | build_id = "0" 108 | build_config = config.get("build", None) 109 | logger.info(f"build_id: [{build_id}]") 110 | logger.debug(f"build_config: {build_config}") 111 | op_run_id = f"{op_run_id}|{build_id}" 112 | build_input_config["build"] = build_config 113 | build_input_config["input"] = config["input"] 114 | self.build_executor.run(op_config, build_input_config, op_run_id) 115 | 116 | config_id += 1 117 | 118 | 119 | def make_default_benchmark(bench_config: BenchmarkConfig): 120 | return Benchmark(bench_config, OpBuildExecutor) 121 | -------------------------------------------------------------------------------- /train/compute/python/lib/pytorch/config_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import enum 3 | import os 4 | import platform 5 | import random 6 | import socket 7 | from typing import Any, Dict, List 8 | 9 | import torch 10 | from torch.utils.collect_env import get_nvidia_driver_version, run as run_cmd 11 | 12 | from ...lib import __version__ 13 | 14 | 15 | @enum.unique 16 | class ExecutionPass(enum.Enum): 17 | # Forward pass will always run (also required for backward pass). 18 | FORWARD = "forward" 19 | 20 | # Run backward pass in addition to forward pass. 21 | BACKWARD = "backward" 22 | 23 | 24 | @enum.unique 25 | class OpExecutionMode(enum.Enum): 26 | # Run operator seprately and clear cache between each call. 27 | DISCRETE = "discrete" 28 | 29 | # Run operator back to back without clear cache, etc. 30 | CONTINUOUS = "continuous" 31 | 32 | # Run operator back to back but record indivisual events. 
33 | CONTINUOUS_EVENTS = "continuous_events" 34 | 35 | 36 | def get_op_run_id(op_name: str, run_id: str) -> str: 37 | return f"{op_name}:{run_id}" 38 | 39 | 40 | def get_benchmark_options() -> dict[str, Any]: 41 | options = { 42 | "device": "cpu", 43 | "pass_type": ExecutionPass.FORWARD, 44 | "warmup": 1, 45 | "iteration": 1, 46 | "op_exec_mode": OpExecutionMode.DISCRETE, 47 | "cuda_l2_cache": False, 48 | "time_unit": "millisecond", 49 | "out_file_prefix": None, 50 | "out_stream": None, 51 | "run_ncu": False, 52 | "ncu_bin": "/usr/local/NVIDIA-Nsight-Compute-2021.2/ncu", 53 | "ncu_args": "", 54 | "ncu_warmup": 5, 55 | "ncu_iteration": 1, 56 | "run_nsys": False, 57 | "nsys_bin": "/opt/nvidia/nsight-systems/2021.4.1/bin/nsys", 58 | "nsys_args": "", 59 | "nsys_warmup": 5, 60 | "nsys_iteration": 10, 61 | "run_batch_size": 50, 62 | "batch_cuda_device": 1, 63 | "batch_cmd": "python -m param_bench.train.compute.python.pytorch.run_batch", 64 | "resume_op_run_id": None, 65 | "stop_op_run_id": None, 66 | } 67 | 68 | return options 69 | 70 | 71 | def create_bench_config(name: str) -> dict[str, Any]: 72 | return {name: create_op_info()} 73 | 74 | 75 | def create_op_info() -> dict[str, Any]: 76 | return { 77 | "build_iterator": None, 78 | "input_iterator": None, 79 | "build_data_generator": None, 80 | "input_data_generator": "PyTorch:DefaultDataGenerator", 81 | "config": [{"build": [], "input": []}], 82 | } 83 | 84 | 85 | def create_op_args(args: list[Any], kwargs: dict[str, Any]) -> dict[str, Any]: 86 | return {"args": args, "kwargs": kwargs} 87 | 88 | 89 | _pytorch_type: dict[str, Any] = { 90 | "int": {"type": "int", "value": None}, 91 | "int_range": {"type": "int", "value_range": None}, 92 | "long": {"type": "long", "value": None}, 93 | "long_range": {"type": "long", "value_range": None}, 94 | "float": {"type": "float", "value": None}, 95 | "float_range": {"type": "float", "value_range": None}, 96 | "double": {"type": "double", "value": None}, 97 | "double_range": {"type": "double", "value_range": None}, 98 | "bool": {"type": "bool", "value": None}, 99 | "device": {"type": "device", "value": None}, 100 | "str": {"type": "str", "value": None}, 101 | "genericlist": {"type": "genericlist", "value": None}, 102 | "tuple": {"type": "tuple", "value": None}, 103 | "tensor": {"type": "tensor", "dtype": "float", "shape": None}, 104 | } 105 | 106 | 107 | def create_type(type) -> dict[str, Any]: 108 | return copy.deepcopy(_pytorch_type[type]) 109 | 110 | 111 | def get_sys_info(): 112 | cuda_available = torch.cuda.is_available() 113 | cuda_info = {} 114 | if cuda_available: 115 | cuda_device_id = torch.cuda.current_device() 116 | cuda_device_property = torch.cuda.get_device_properties(cuda_device_id) 117 | cuda_info = { 118 | "cuda": torch.version.cuda, 119 | "cuda_device_driver": get_nvidia_driver_version(run_cmd), 120 | "cuda_gencode": torch.cuda.get_gencode_flags(), 121 | "cuda_device_id": cuda_device_id, 122 | "cuda_device_name": torch.cuda.get_device_name(), 123 | "cuda_device_property": cuda_device_property, 124 | "cudnn": torch.backends.cudnn.version(), 125 | "cudnn_enabled": torch.backends.cudnn.enabled, 126 | } 127 | 128 | return { 129 | "hostname": socket.gethostname(), 130 | "pid": os.getpid(), 131 | "cwd": os.getcwd(), 132 | "python_version": platform.python_version(), 133 | "param_train_compute_version": __version__, 134 | "cuda_available": cuda_available, 135 | **cuda_info, 136 | "pytorch_version": torch.__version__, 137 | "pytorch_debug_build": torch.version.debug, 138 | "pytorch_build_config": 
torch._C._show_config(), 139 | } 140 | 141 | 142 | def init_pytorch(run_options: dict[str, Any]): 143 | # We don't want too many threads for stable benchmarks 144 | torch.set_num_threads(1) 145 | 146 | # Fix random number generator seeds. 147 | torch.manual_seed(0) 148 | random.seed(0) 149 | -------------------------------------------------------------------------------- /train/compute/python/lib/pytorch/cuda_util.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import torch 4 | 5 | from ..init_helper import get_logger 6 | 7 | logger = get_logger() 8 | 9 | 10 | def log_cuda_memory_usage(): 11 | cuda_allocated = torch.cuda.memory_allocated() / 1048576 12 | cuda_reserved = torch.cuda.memory_reserved() / 1048576 13 | logger.info( 14 | f"CUDA memory allocated = {cuda_allocated:.3f} MB, reserved = {cuda_reserved:.3f} MB" 15 | ) 16 | 17 | 18 | def free_torch_cuda_memory(): 19 | gc.collect() 20 | torch.cuda.empty_cache() 21 | -------------------------------------------------------------------------------- /train/compute/python/lib/pytorch/data_impl.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | from collections.abc import Callable 4 | from typing import Any, Dict, List, Set 5 | 6 | import torch 7 | 8 | from ..data import DataGenerator, register_data_generator 9 | from ..init_helper import get_logger 10 | 11 | logger = get_logger() 12 | 13 | pytorch_int_dtype_map: dict[str, torch.dtype] = { 14 | "uint8": torch.uint8, 15 | "int8": torch.int8, 16 | "int16": torch.int16, 17 | "int": torch.int32, 18 | "long": torch.int64, 19 | } 20 | pytorch_float_dtype_map: dict[str, torch.dtype] = { 21 | "float": torch.float32, 22 | "double": torch.float64, 23 | "float16": torch.float16, 24 | "bfloat16": torch.bfloat16, 25 | } 26 | pytorch_dtype_map: dict[str, torch.dtype] = { 27 | **pytorch_int_dtype_map, 28 | **pytorch_float_dtype_map, 29 | "bool": torch.bool, 30 | } 31 | 32 | 33 | def materialize_arg(arg: dict[str, Any], device: str) -> Any: 34 | """ 35 | Given an arg configuration, materialize the test data for that arg. 
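    For example, {"type": "tensor", "dtype": "float", "shape": [4, 4]} yields a
    random 4x4 float32 tensor on `device`, while {"type": "int", "value": 1}
    simply returns 1 (see the per-type create_* helpers below).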
36 | """ 37 | 38 | def create_tensor(attr: dict[str, Any]): 39 | shape = attr["shape"] 40 | requires_grad = attr.get("requires_grad", True) 41 | if len(shape) > 0: 42 | if attr["dtype"] in pytorch_float_dtype_map: 43 | return torch.rand( 44 | *shape, 45 | dtype=pytorch_dtype_map[attr["dtype"]], 46 | requires_grad=requires_grad, 47 | device=torch.device(device), 48 | ) 49 | elif attr["dtype"] in pytorch_int_dtype_map: 50 | return torch.randint( 51 | -10, 52 | 10, 53 | tuple(shape), 54 | dtype=pytorch_dtype_map[attr["dtype"]], 55 | requires_grad=requires_grad, 56 | device=torch.device(device), 57 | ) 58 | elif attr["dtype"] == "bool": 59 | return ( 60 | torch.rand( 61 | *shape, 62 | dtype=pytorch_dtype_map["float"], 63 | requires_grad=requires_grad, 64 | device=torch.device(device), 65 | ) 66 | < 0.5 67 | ) 68 | # Single value 69 | else: 70 | return torch.tensor( 71 | random.uniform(-10.0, 10.0), 72 | dtype=pytorch_dtype_map[attr["dtype"]], 73 | requires_grad=requires_grad, 74 | device=torch.device(device), 75 | ) 76 | 77 | def create_float(attr: dict[str, Any]): 78 | if "value" in attr: 79 | return attr["value"] 80 | return random.uniform(attr["value_range"][0], attr["value_range"][1]) 81 | 82 | def create_int(attr: dict[str, Any]): 83 | # check "value" key exists, attr["value"] = 0 could be eval to False 84 | if "value" in attr: 85 | return attr["value"] 86 | return random.randint(attr["value_range"][0], attr["value_range"][1]) 87 | 88 | def create_str(attr: dict[str, Any]): 89 | # check "value" key exists, attr["value"] = 0 could be eval to False 90 | if "value" in attr: 91 | return attr["value"] 92 | return "" 93 | 94 | def create_bool(attr: dict[str, Any]): 95 | return attr["value"] 96 | 97 | def create_none(attr: dict[str, Any]): 98 | return None 99 | 100 | def create_device(attr: dict[str, Any]): 101 | return torch.device(attr["value"]) 102 | 103 | def create_genericlist(attr: list[Any]): 104 | result = [] 105 | for item in attr["value"]: 106 | result.append(arg_factory[item["type"]](item)) 107 | return result 108 | 109 | def create_tuple(attr: list[Any]): 110 | result = create_genericlist(attr) 111 | return tuple(result) 112 | 113 | # Map of argument types to the create methods. 114 | arg_factory: dict[str, Callable] = { 115 | "tensor": create_tensor, 116 | "float": create_float, 117 | "double": create_float, 118 | "int": create_int, 119 | "long": create_int, 120 | "none": create_none, 121 | "bool": create_bool, 122 | "device": create_device, 123 | "str": create_str, 124 | "genericlist": create_genericlist, 125 | "tuple": create_tuple, 126 | } 127 | return arg_factory[arg["type"]](arg) 128 | 129 | 130 | # DefaultDataGenerator 131 | class DefaultDataGenerator(DataGenerator): 132 | def __init__(self, cache: bool = False): 133 | super().__init__() 134 | # keep track/cache last arg_config so we only generate data for 135 | # args that's different from previous iteration. 
136 | self.cache = cache 137 | self.prev_config = None 138 | self.op_args = [] 139 | self.op_kwargs = {} 140 | 141 | def _find_updates(self, config: dict[str, Any]): 142 | if not self.prev_config: 143 | return (None, None) 144 | arg_updates = set() 145 | kwarg_updates = set() 146 | if "args" in config: 147 | for i, vals in enumerate(zip(self.prev_config["args"], config["args"])): 148 | if vals[0] != vals[1]: 149 | arg_updates.add(i) 150 | if "kwargs" in config: 151 | for key in self.prev_config["kwargs"]: 152 | if self.prev_config["kwargs"][key] != config["kwargs"][key]: 153 | kwarg_updates.add(key) 154 | 155 | logger.debug(f" prev: {self.prev_config}") 156 | logger.debug(f" curr: {config}") 157 | logger.debug(f" updt: {arg_updates} {kwarg_updates}") 158 | return (arg_updates, kwarg_updates) 159 | 160 | def _generate_data( 161 | self, 162 | config: dict[str, Any], 163 | device: str, 164 | op_args: list[Any], # potentially cached container 165 | op_kwargs: dict[str, Any], # potentially cached container 166 | arg_updates: set[Any], 167 | kwarg_updates: set[Any], 168 | ): 169 | # initialize positional args array if empty (not cached). 170 | if len(op_args) == 0: 171 | op_args = [None] * len(config["args"]) 172 | if "args" in config: 173 | for i, arg in enumerate(config["args"]): 174 | if arg_updates: 175 | if i in arg_updates: 176 | op_args[i] = materialize_arg(arg, device) 177 | else: 178 | op_args[i] = materialize_arg(arg, device) 179 | 180 | if "kwargs" in config: 181 | for key, arg in config["kwargs"].items(): 182 | if kwarg_updates: 183 | if key in kwarg_updates: 184 | op_kwargs[key] = materialize_arg(arg, device) 185 | else: 186 | op_kwargs[key] = materialize_arg(arg, device) 187 | return (op_args, op_kwargs) 188 | 189 | def get_data(self, config: dict[str, Any], device: str): 190 | if not config: 191 | # No configs, just return empty args. 192 | return ([], {}) 193 | elif self.cache: 194 | # find the arg config that changed from previous iteration 195 | arg_updates, kwarg_updates = self._find_updates(config) 196 | # cache arg configs for next iteration to compare. 197 | self.prev_config = copy.deepcopy(config) 198 | return self._generate_data( 199 | config, device, self.op_args, self.op_kwargs, arg_updates, kwarg_updates 200 | ) 201 | else: 202 | op_args = [] 203 | op_kwargs = {} 204 | return self._generate_data(config, device, op_args, op_kwargs, None, None) 205 | 206 | 207 | register_data_generator("PyTorch:DefaultDataGenerator", DefaultDataGenerator) 208 | -------------------------------------------------------------------------------- /train/compute/python/lib/pytorch/operator_impl.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import List 3 | 4 | from ..init_helper import get_logger 5 | 6 | logger = get_logger() 7 | 8 | import re 9 | 10 | import torch 11 | 12 | from ..operator import OperatorInterface 13 | 14 | 15 | class UnaryOp(OperatorInterface): 16 | """ 17 | UnaryOp is called in the form of tensor_obj.op(args), we convert it 18 | to a regular function call with "getattr(tensor_obj, op)(args)". So the 19 | first arg is assumed to be the `tensor_obj`. 
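    For example, with func_name "add_" and args (tensor, other), forward()
    effectively runs tensor.add_(other) in-place under torch.no_grad().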
20 | """ 21 | 22 | def __init__( 23 | self, 24 | func_name: str, 25 | ): 26 | super().__init__() 27 | self.func_name: str = func_name 28 | self.fwd_out: torch.tensor = None 29 | self.grad_in: torch.tensor = None 30 | 31 | def forward(self, *args, **kwargs): 32 | # The first arg is assume to be the inplace value, pass on the rest of 33 | # the args to the callable. 34 | # Unary op also does not support backward() because they are in-place. 35 | with torch.no_grad(): 36 | getattr(args[0], self.func_name)(*args[1:], **kwargs) 37 | 38 | def create_grad(self): 39 | pass 40 | 41 | def backward(self): 42 | pass 43 | 44 | 45 | class CallableOp(OperatorInterface): 46 | """ 47 | Callable ops are ops can be called in the form of op(*args, **kwargs) 48 | """ 49 | 50 | def __init__( 51 | self, 52 | func: Callable, 53 | ): 54 | super().__init__() 55 | self.func: Callable = func 56 | self.fwd_out: torch.tensor = None 57 | self.grad_in = None 58 | 59 | def cleanup(self): 60 | self.fwd_out = None 61 | self.grad_in = None 62 | 63 | def forward(self, *args, **kwargs): 64 | self.fwd_out = self.func(*args, **kwargs) 65 | return self.fwd_out 66 | 67 | def create_grad(self): 68 | if not self.fwd_out.is_leaf: 69 | self.grad_in = torch.ones_like(self.fwd_out) 70 | else: 71 | logger.debug( 72 | f"{self.constructor.__name__}: skipping create_grad() due to forward result is leaf tensor." 73 | ) 74 | 75 | def backward(self): 76 | if not self.fwd_out.is_leaf: 77 | self.fwd_out.backward(self.grad_in) 78 | else: 79 | logger.debug( 80 | f"{self.constructor.__name__}: skipping backward() due to forward result is leaf tensor." 81 | ) 82 | 83 | 84 | class BuildableOp(OperatorInterface): 85 | """ 86 | BuildableOp are ops needs to be constructed first, before running with inputs. 87 | """ 88 | 89 | def __init__( 90 | self, 91 | constructor: Callable, 92 | ): 93 | super().__init__() 94 | self.constructor: Callable = constructor 95 | self.func: Callable = None 96 | self.fwd_out: torch.tensor = None 97 | self.grad_in = None 98 | 99 | # Construct and initialize the operator. 100 | def build(self, *args, **kwargs): 101 | # Use `to` to make sure weights are on device. 102 | self.func = self.constructor(*args, **kwargs).to(torch.device(self.device)) 103 | 104 | def cleanup(self): 105 | self.fwd_out = None 106 | self.grad_in = None 107 | 108 | def forward(self, *args, **kwargs): 109 | self.fwd_out = self.func(*args, **kwargs) 110 | return self.fwd_out 111 | 112 | def create_grad(self): 113 | if not self.fwd_out.is_leaf: 114 | self.grad_in = torch.ones_like(self.fwd_out) 115 | else: 116 | logger.debug( 117 | f"{self.constructor.__name__}: skipping create_grad() due to forward result is leaf tensor." 118 | ) 119 | 120 | def backward(self): 121 | if not self.fwd_out.is_leaf: 122 | self.fwd_out.backward(self.grad_in) 123 | else: 124 | logger.debug( 125 | f"{self.constructor.__name__}: skipping backward() due to forward result is leaf tensor." 126 | ) 127 | 128 | 129 | class TorchScriptOp(OperatorInterface): 130 | """ 131 | TorchScriptOp generates a graph IR that runs a specific PyTorch function in 132 | the SSA form. 133 | """ 134 | 135 | def __init__( 136 | self, 137 | func_name: str, 138 | ): 139 | super().__init__() 140 | self.func_name: str = func_name 141 | self.func: Callable = None 142 | self.fwd_out: torch.tensor = None 143 | self.grad_in: torch.tensor = None 144 | 145 | def build(self, op_schema: str): 146 | """ 147 | Because TorchScript is in SSA form, we expect at least one element in 148 | types for the output. 
An example is: 149 | ``` 150 | graph(%0 : Tensor, 151 | %1 : Tensor, 152 | %2 : int): 153 | %3 : Tensor = aten::add(%0, %1, %2) 154 | return (%3) 155 | ``` 156 | """ 157 | 158 | def _extract_types(types_str: str): 159 | # split into args in the form of "type_name var_name". 160 | types = [item for item in types_str.split(",")] 161 | # separate betwen type and var name, keep types, skip *. 162 | types = [item.strip().split(" ")[0] for item in types if "*" not in item] 163 | # remove list length, e.g. int[2] -> int[]. 164 | types = [re.sub(r"\[[0-9]\]", "[]", t) for t in types] 165 | # remove elem type for tensors, e.g. Tensor(float) -> Tensor 166 | var_types = [item if "Tensor" not in item else "Tensor" for item in types] 167 | return var_types 168 | 169 | assert ( 170 | op_schema 171 | ), f"TorchScriptOp {self.func_name} should have at non-empty op schema." 172 | 173 | func_name, func_signature = op_schema.split("(", 1) 174 | arg_str, output_str = func_signature.split("->", 1) 175 | arg_str = arg_str.strip("() ") 176 | output_str = output_str.strip("() ") 177 | arg_types = _extract_types(arg_str) 178 | output_types = _extract_types(output_str) 179 | 180 | graph_args = [] 181 | func_args = [] 182 | 183 | func_schema = torch._C.parse_schema(op_schema) 184 | register_id = 0 185 | for data_type in arg_types: 186 | graph_args.append(f"%{register_id} : {data_type}") 187 | func_args.append(f"%{register_id}") 188 | register_id += 1 189 | 190 | func_outputs = [] 191 | func_output_vars = [] 192 | func_output_types = [] 193 | for data_type in output_types: 194 | func_outputs.append(f"%{register_id} : {data_type}") 195 | func_output_vars.append(f"%{register_id}") 196 | func_output_types.append(data_type) 197 | output_var = f"%{register_id}" 198 | register_id += 1 199 | return_construct = "" 200 | if len(func_outputs) > 1: 201 | return_construct = f"%{register_id}: ({','.join(func_output_types)}) = prim::TupleConstruct({','.join(func_output_vars)})" 202 | output_var = f"%{register_id}" 203 | actual_func_name = func_schema.name 204 | 205 | ts_ir = f""" 206 | graph({",".join(graph_args)}): 207 | {",".join(func_outputs)} = {actual_func_name}({",".join(func_args)}) 208 | {return_construct} 209 | return ({output_var}) 210 | """ 211 | ts_graph = torch._C.parse_ir(ts_ir) 212 | logger.debug(f"{self.func_name} TorchScript IR Graph: \n{ts_graph}") 213 | cu = torch._C.CompilationUnit() 214 | self.func = cu.create_function(self.func_name, ts_graph) 215 | 216 | def cleanup(self): 217 | self.fwd_out = None 218 | self.grad_in = None 219 | 220 | def forward(self, *args, **kwargs): 221 | self.fwd_out = self.func(*args, **kwargs) 222 | return self.fwd_out 223 | 224 | def create_grad(self): 225 | if not self.fwd_out.is_leaf: 226 | self.grad_in = torch.ones_like(self.fwd_out) 227 | else: 228 | logger.debug( 229 | f"{self.func_name}: skipping create_grad() due to forward result is leaf tensor." 230 | ) 231 | 232 | def backward(self): 233 | if not self.fwd_out.is_leaf: 234 | self.fwd_out.backward(self.grad_in) 235 | else: 236 | logger.debug( 237 | f"{self.func_name}: skipping backward() due to forward result is leaf tensor." 
--------------------------------------------------------------------------------
/train/compute/python/lib/pytorch/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | import torch
4 | 
5 | 
6 | # Timer
7 | class Timer:
8 |     def __init__(self, device: str):
9 |         self.device: str = device
10 |         if self.device is None:
11 |             self.torch_device = None
12 |         else:
13 |             self.torch_device = torch.device(self.device)
14 |         self.start_time: float = 0
15 |         self.end_time: float = 0
16 | 
17 |     def start(self):
18 |         if self.device and self.device.startswith("cuda"):
19 |             torch.cuda.synchronize(self.torch_device)
20 |         self.start_time = time.perf_counter_ns()
21 | 
22 |     def stop(self):
23 |         if self.device and self.device.startswith("cuda"):
24 |             torch.cuda.synchronize(self.torch_device)
25 |         self.end_time = time.perf_counter_ns()
26 | 
27 |     # Return result in milliseconds.
28 |     def elapsed_time_ms(self) -> float:
29 |         return (self.end_time - self.start_time) / 1e6
30 | 
31 |     def elapsed_time_sec(self) -> float:
32 |         return (self.end_time - self.start_time) / 1e9
33 | 
--------------------------------------------------------------------------------
/train/compute/python/pytorch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/train/compute/python/pytorch/__init__.py
--------------------------------------------------------------------------------
/train/compute/python/pytorch/run_batch.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | from ..lib.init_helper import init_logging, load_modules
4 | 
5 | # Initialize logging format before loading all other modules
6 | logger = init_logging(logging.INFO)
7 | 
8 | import argparse
9 | import json
10 | import os
11 | from multiprocessing import resource_tracker, shared_memory
12 | 
13 | from ..lib import pytorch as lib_pytorch
14 | from ..lib.config import make_op_config
15 | from ..lib.pytorch.build_executor import MaterializedBuildExecutor
16 | from ..lib.pytorch.config_util import ExecutionPass, OpExecutionMode
17 | from ..workloads import pytorch as workloads_pytorch
18 | 
19 | 
20 | def main():
21 |     # Load PyTorch implementations for data generator and operators.
22 |     load_modules(lib_pytorch)
23 | 
24 |     # Load PyTorch operator workloads.
25 |     load_modules(workloads_pytorch)
26 | 
27 |     parser = argparse.ArgumentParser(description="Microbenchmarks")
28 |     parser.add_argument(
29 |         "-s",
30 |         "--shm",
31 |         type=str,
32 |         required=False,
33 |         help="The shared memory buffer name for the config.",
34 |     )
35 |     parser.add_argument(
36 |         "-f", "--file", type=str, required=False, help="The file name for the config."
37 |     )
38 |     parser.add_argument(
39 |         "-v", "--verbose", action="store_true", help="Increase log output verbosity."
40 |     )
41 | 
42 |     args = parser.parse_args()
43 | 
44 |     if args.verbose:
45 |         init_logging(logging.DEBUG)
46 | 
47 |     if args.shm:
48 |         """
49 |         Python's shared_memory module has a bug in how shared segments are
50 |         tracked and released; see https://bugs.python.org/issue39959
51 |         (fix: https://github.com/python/cpython/pull/20136).
52 |         Workaround: unregister this segment from the resource_tracker.
53 | """ 54 | shm = shared_memory.SharedMemory(args.shm) 55 | logger.debug(f"shared memory: {shm.name}") 56 | resource_tracker.unregister(shm._name, "shared_memory") 57 | config = json.loads(bytes(shm.buf[:]).decode("utf-8", "strict")) 58 | shm.close() 59 | elif args.file: 60 | with open(args.file) as config_file: 61 | config = json.load(config_file) 62 | else: 63 | logger.info("no inputs provided.") 64 | return 65 | 66 | op_name = config["op_name"] 67 | config_build_id = config["config_build_id"] 68 | op_info = config["op_info"] 69 | run_options = config["run_options"] 70 | 71 | logger.debug(f"op_name: {op_name}") 72 | logger.debug(f"config_build_id: {config_build_id}") 73 | logger.debug(f"op_info: {op_info}") 74 | logger.debug(f"run_options: {run_options}") 75 | 76 | run_options["pass_type"] = ExecutionPass(run_options["pass_type"]) 77 | run_options["op_exec_mode"] = OpExecutionMode(run_options["op_exec_mode"]) 78 | 79 | op_config = make_op_config(op_name, op_info, run_options["device"]) 80 | build_input_config = op_info["config"][0] 81 | 82 | # Don't need to write out anything. 83 | with open(os.devnull, "w") as out_stream: 84 | run_options["out_stream"] = out_stream 85 | build_exe = MaterializedBuildExecutor(run_options) 86 | build_exe.run(op_config, build_input_config, config_build_id) 87 | 88 | 89 | if __name__ == "__main__": 90 | main() 91 | -------------------------------------------------------------------------------- /train/compute/python/requirements.txt: -------------------------------------------------------------------------------- 1 | fbgemm_gpu 2 | gitpython 3 | networkx 4 | numpy 5 | pydot 6 | scipy 7 | torch>=2.0.0 8 | -------------------------------------------------------------------------------- /train/compute/python/setup.py: -------------------------------------------------------------------------------- 1 | from lib import __generate_git_param_train_compute_version 2 | from setuptools import setup 3 | 4 | 5 | def main(): 6 | package_base = "param_bench.train.compute.python" 7 | 8 | # List the packages and their dir mapping: 9 | # "install_destination_package_path": "source_dir_path" 10 | package_dir_map = { 11 | f"{package_base}": ".", 12 | f"{package_base}.examples": "examples", 13 | f"{package_base}.examples.pytorch": "examples/pytorch", 14 | f"{package_base}.lib": "lib", 15 | f"{package_base}.lib.pytorch": "lib/pytorch", 16 | f"{package_base}.pytorch": "pytorch", 17 | f"{package_base}.test": "test", 18 | f"{package_base}.test.pytorch": "test/pytorch", 19 | f"{package_base}.tools": "tools", 20 | f"{package_base}.workloads": "workloads", 21 | f"{package_base}.workloads.pytorch": "workloads/pytorch", 22 | } 23 | 24 | packages = list(package_dir_map) 25 | 26 | param_train_compute_version = __generate_git_param_train_compute_version() 27 | with open("./lib/_version.py", "w") as version_out: 28 | version_out.write( 29 | f"__param_train_compute_version='{param_train_compute_version}'" 30 | ) 31 | 32 | setup( 33 | name="parambench-train-compute", 34 | version=param_train_compute_version, 35 | python_requires=">=3.8", 36 | author="Louis Feng", 37 | author_email="lofe@fb.com", 38 | url="https://github.com/facebookresearch/param", 39 | packages=packages, 40 | package_dir=package_dir_map, 41 | ) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /train/compute/python/test/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/train/compute/python/test/__init__.py -------------------------------------------------------------------------------- /train/compute/python/test/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/train/compute/python/test/pytorch/__init__.py -------------------------------------------------------------------------------- /train/compute/python/test/test_benchmark_load.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import unittest 4 | 5 | from param_bench.train.compute.python.lib import pytorch as lib_pytorch 6 | 7 | from param_bench.train.compute.python.lib.config import BenchmarkConfig 8 | from param_bench.train.compute.python.lib.init_helper import load_modules 9 | from param_bench.train.compute.python.lib.pytorch.benchmark import ( 10 | Benchmark, 11 | make_default_benchmark, 12 | ) 13 | from param_bench.train.compute.python.lib.pytorch.config_util import ( 14 | get_benchmark_options, 15 | ) 16 | 17 | CURR_DIR = os.path.dirname(os.path.realpath(__file__)) 18 | 19 | 20 | class TestBenchmarkLoad(unittest.TestCase): 21 | def setUp(self): 22 | self.config_path = os.path.join( 23 | CURR_DIR, "pytorch", "configs", "test_native_basic_ops.json" 24 | ) 25 | # Load PyTorch implementations for data generator and operators. 26 | load_modules(lib_pytorch) 27 | 28 | def test_json_load_benchmark(self): 29 | run_options = get_benchmark_options() 30 | bench_config = BenchmarkConfig(run_options) 31 | bench_config.load_json_file(self.config_path) 32 | benchmark = make_default_benchmark(bench_config) 33 | self.assertTrue(isinstance(benchmark, Benchmark)) 34 | self.assertTrue(len(benchmark.run_options) > 0) 35 | self.assertTrue(len(benchmark.bench_config.bench_config) > 0) 36 | 37 | 38 | if __name__ == "__main__": 39 | unittest.main() 40 | -------------------------------------------------------------------------------- /train/compute/python/test/test_generator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import unittest 3 | 4 | from param_bench.train.compute.python.lib.generator import ( 5 | full_range, 6 | IterableList, 7 | ListProduct, 8 | TableProduct, 9 | ) 10 | 11 | 12 | class TestGenerator(unittest.TestCase): 13 | def test_full_range(self): 14 | def gen(start, end, step): 15 | result = [] 16 | x = full_range(start, end, step) 17 | for i in x: 18 | result.append(i) 19 | return result 20 | 21 | result = gen(-3, 2, 1) 22 | expected = [-3, -2, -1, 0, 1, 2] 23 | self.assertEqual(result, expected) 24 | 25 | expected = [5, 7, 9, 11] 26 | result = gen(5, 11, 2) 27 | self.assertEqual(result, expected) 28 | 29 | result = gen(3, 11, 3) 30 | expected = [3, 6, 9] 31 | self.assertEqual(result, expected) 32 | 33 | def test_iterable_List(self): 34 | simple_list = IterableList([2, 4, 6, 8]) 35 | result = [] 36 | for item in simple_list: 37 | result.append(item) 38 | expected = [2, 4, 6, 8] 39 | 40 | self.assertEqual(result, expected) 41 | 42 | # Case with IterableList with nested ListProduct 43 | simple_list = IterableList([1, 2, ListProduct([1, full_range(3, 5)])]) 44 | result = [] 45 | for item in simple_list: 46 | if isinstance(item, ListProduct): 47 | this_result = [] 48 | for sub_item in item: 49 | 
this_result.append(copy.deepcopy(sub_item)) 50 | result.append(this_result) 51 | else: 52 | result.append(item) 53 | expected = [1, 2, [[1, 3], [1, 4], [1, 5]]] 54 | 55 | self.assertEqual(result, expected) 56 | 57 | # Empty List Case 58 | simple_list = IterableList([]) 59 | result = [] 60 | for item in simple_list: 61 | result.append(item) 62 | expected = [] 63 | 64 | self.assertEqual(result, expected) 65 | 66 | def test_list_product(self): 67 | iter_list = [1, full_range(3, 5), 2, full_range(7, 13, 3)] 68 | result = [] 69 | for gen_list in ListProduct(iter_list): 70 | result.append(copy.deepcopy(gen_list)) 71 | expected = [ 72 | [1, 3, 2, 7], 73 | [1, 3, 2, 10], 74 | [1, 3, 2, 13], 75 | [1, 4, 2, 7], 76 | [1, 4, 2, 10], 77 | [1, 4, 2, 13], 78 | [1, 5, 2, 7], 79 | [1, 5, 2, 10], 80 | [1, 5, 2, 13], 81 | ] 82 | self.assertEqual(result, expected) 83 | 84 | iter_list = [ 85 | 1, 86 | ListProduct([2, full_range(3, 5)]), 87 | 6.7, 88 | TableProduct( 89 | { 90 | "A": ListProduct([6, full_range(2, 7, 2)]), 91 | "B": IterableList(["str 1", "str 2"]), 92 | } 93 | ), 94 | "str 3", 95 | ] 96 | result = [] 97 | for gen_list in ListProduct(iter_list): 98 | result.append(copy.deepcopy(gen_list)) 99 | expected = [ 100 | [1, [2, 3], 6.7, {"A": [6, 2], "B": "str 1"}, "str 3"], 101 | [1, [2, 3], 6.7, {"A": [6, 2], "B": "str 2"}, "str 3"], 102 | [1, [2, 3], 6.7, {"A": [6, 4], "B": "str 1"}, "str 3"], 103 | [1, [2, 3], 6.7, {"A": [6, 4], "B": "str 2"}, "str 3"], 104 | [1, [2, 3], 6.7, {"A": [6, 6], "B": "str 1"}, "str 3"], 105 | [1, [2, 3], 6.7, {"A": [6, 6], "B": "str 2"}, "str 3"], 106 | [1, [2, 4], 6.7, {"A": [6, 2], "B": "str 1"}, "str 3"], 107 | [1, [2, 4], 6.7, {"A": [6, 2], "B": "str 2"}, "str 3"], 108 | [1, [2, 4], 6.7, {"A": [6, 4], "B": "str 1"}, "str 3"], 109 | [1, [2, 4], 6.7, {"A": [6, 4], "B": "str 2"}, "str 3"], 110 | [1, [2, 4], 6.7, {"A": [6, 6], "B": "str 1"}, "str 3"], 111 | [1, [2, 4], 6.7, {"A": [6, 6], "B": "str 2"}, "str 3"], 112 | [1, [2, 5], 6.7, {"A": [6, 2], "B": "str 1"}, "str 3"], 113 | [1, [2, 5], 6.7, {"A": [6, 2], "B": "str 2"}, "str 3"], 114 | [1, [2, 5], 6.7, {"A": [6, 4], "B": "str 1"}, "str 3"], 115 | [1, [2, 5], 6.7, {"A": [6, 4], "B": "str 2"}, "str 3"], 116 | [1, [2, 5], 6.7, {"A": [6, 6], "B": "str 1"}, "str 3"], 117 | [1, [2, 5], 6.7, {"A": [6, 6], "B": "str 2"}, "str 3"], 118 | ] 119 | self.assertEqual(result, expected) 120 | 121 | iter_list = [] 122 | result = [] 123 | for gen_list in ListProduct(iter_list): 124 | result.append(gen_list) 125 | expected = [[]] 126 | self.assertEqual(result, expected) 127 | 128 | def test_table_product(self): 129 | iter_dict = {"A": 1, "B": full_range(3, 5), "C": 2, "D": full_range(7, 13, 3)} 130 | result = [] 131 | for gen_dict in TableProduct(iter_dict): 132 | result.append(copy.deepcopy(gen_dict)) 133 | expected = [ 134 | {"A": 1, "B": 3, "C": 2, "D": 7}, 135 | {"A": 1, "B": 3, "C": 2, "D": 10}, 136 | {"A": 1, "B": 3, "C": 2, "D": 13}, 137 | {"A": 1, "B": 4, "C": 2, "D": 7}, 138 | {"A": 1, "B": 4, "C": 2, "D": 10}, 139 | {"A": 1, "B": 4, "C": 2, "D": 13}, 140 | {"A": 1, "B": 5, "C": 2, "D": 7}, 141 | {"A": 1, "B": 5, "C": 2, "D": 10}, 142 | {"A": 1, "B": 5, "C": 2, "D": 13}, 143 | ] 144 | self.assertEqual(result, expected) 145 | 146 | iter_dict = { 147 | "A": 1, 148 | "B": ListProduct([3, full_range(4, 5)]), 149 | "C": 2, 150 | "D": TableProduct({"E": full_range(7, 9), "F": 10}), 151 | } 152 | result = [] 153 | for gen_dict in TableProduct(iter_dict): 154 | result.append(copy.deepcopy(gen_dict)) 155 | expected = [ 156 | {"A": 1, 
"B": [3, 4], "C": 2, "D": {"E": 7, "F": 10}}, 157 | {"A": 1, "B": [3, 4], "C": 2, "D": {"E": 8, "F": 10}}, 158 | {"A": 1, "B": [3, 4], "C": 2, "D": {"E": 9, "F": 10}}, 159 | {"A": 1, "B": [3, 5], "C": 2, "D": {"E": 7, "F": 10}}, 160 | {"A": 1, "B": [3, 5], "C": 2, "D": {"E": 8, "F": 10}}, 161 | {"A": 1, "B": [3, 5], "C": 2, "D": {"E": 9, "F": 10}}, 162 | ] 163 | self.assertEqual(result, expected) 164 | 165 | iter_dict = {} 166 | result = [] 167 | for gen_dict in TableProduct(iter_dict): 168 | result.append(gen_dict) 169 | expected = [{}] 170 | self.assertEqual(result, expected) 171 | 172 | 173 | if __name__ == "__main__": 174 | unittest.main() 175 | -------------------------------------------------------------------------------- /train/compute/python/test/test_register.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from param_bench.train.compute.python.lib.data import ( 4 | data_generator_map, 5 | DataGenerator, 6 | register_data_generator, 7 | ) 8 | from param_bench.train.compute.python.lib.iterator import ( 9 | config_iterator_map, 10 | ConfigIterator, 11 | register_config_iterator, 12 | ) 13 | from param_bench.train.compute.python.lib.operator import ( 14 | op_map, 15 | OperatorInterface, 16 | register_operator, 17 | register_operators, 18 | ) 19 | 20 | 21 | class TestRegister(unittest.TestCase): 22 | def test_register_config_iterator(self): 23 | class TestConfigIterator(ConfigIterator): 24 | pass 25 | 26 | name = "__TestConfigIterator__" 27 | register_config_iterator(name, TestConfigIterator) 28 | self.assertTrue(name in config_iterator_map) 29 | self.assertRaises( 30 | ValueError, register_config_iterator, name, TestConfigIterator 31 | ) 32 | 33 | def test_register_data_generator(self): 34 | class TestDataGenerator(DataGenerator): 35 | pass 36 | 37 | name = "__TestDataGenerator__" 38 | register_data_generator(name, TestDataGenerator) 39 | self.assertTrue(name in data_generator_map) 40 | self.assertRaises(ValueError, register_data_generator, name, TestDataGenerator) 41 | 42 | def test_register_operator(self): 43 | class TestOperator(OperatorInterface): 44 | pass 45 | 46 | name = "__TestOperator__" 47 | register_operator(name, TestOperator) 48 | self.assertTrue(name in op_map) 49 | self.assertRaises(ValueError, register_operator, name, TestOperator) 50 | 51 | name_1 = "__TestOperator_1__" 52 | name_2 = "__TestOperator_2__" 53 | register_operators({name_1: TestOperator, name_2: TestOperator}) 54 | self.assertTrue(name_1 in op_map) 55 | self.assertTrue(name_2 in op_map) 56 | self.assertRaises(ValueError, register_operators, {name_1: TestOperator}) 57 | self.assertRaises(ValueError, register_operators, {name_2: TestOperator}) 58 | self.assertRaises( 59 | ValueError, register_operators, {name_1: TestOperator, name_2: TestOperator} 60 | ) 61 | 62 | 63 | if __name__ == "__main__": 64 | unittest.main() 65 | -------------------------------------------------------------------------------- /train/compute/python/test/test_split_table_batched_embeddings_ops.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from param_bench.train.compute.python.lib.config import make_op_config 4 | from param_bench.train.compute.python.lib.pytorch.config_util import create_op_info 5 | from param_bench.train.compute.python.workloads.pytorch import ( # noqa 6 | split_table_batched_embeddings_ops, 7 | ) 8 | 9 | 10 | class TestSplitTableBatchedEmbeddingOps(unittest.TestCase): 11 | 
@unittest.skip("fbgemm is failing.") 12 | def test_build_op(self): 13 | op_name = "SplitTableBatchedEmbeddingBagsCodegen" 14 | op_info = create_op_info() 15 | op_info["input_data_generator"] = ( 16 | "SplitTableBatchedEmbeddingBagsCodegenInputDataGenerator" 17 | ) 18 | op_config = make_op_config(op_name, op_info, "cpu") 19 | op_config.op.cleanup() 20 | 21 | op_config.op.build( 22 | 1, 23 | [1000], 24 | [64], 25 | 0, # PoolingMode.SUM 26 | False, 27 | "fp16", 28 | "exact_adagrad", 29 | ) 30 | self.assertEqual(op_config.op.op.embedding_specs[0][0], 1000) 31 | self.assertEqual(op_config.op.op.embedding_specs[0][1], 64) 32 | 33 | op_config.op.build( 34 | 1, 35 | 2000, 36 | 128, 37 | 0, # PoolingMode.SUM 38 | False, 39 | "fp16", 40 | "exact_adagrad", 41 | ) 42 | self.assertEqual(op_config.op.op.embedding_specs[0][0], 2000) 43 | self.assertEqual(op_config.op.op.embedding_specs[0][1], 128) 44 | 45 | op_config.op.build( 46 | 2, 47 | [1000, 2000], 48 | [64, 128], 49 | 0, # PoolingMode.SUM 50 | False, 51 | "fp16", 52 | "exact_adagrad", 53 | ) 54 | self.assertEqual(op_config.op.op.embedding_specs[0][0], 1000) 55 | self.assertEqual(op_config.op.op.embedding_specs[1][0], 2000) 56 | self.assertEqual(op_config.op.op.embedding_specs[0][1], 64) 57 | self.assertEqual(op_config.op.op.embedding_specs[1][1], 128) 58 | 59 | 60 | if __name__ == "__main__": 61 | unittest.main() 62 | -------------------------------------------------------------------------------- /train/compute/python/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/train/compute/python/tools/__init__.py -------------------------------------------------------------------------------- /train/compute/python/workloads/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/train/compute/python/workloads/__init__.py -------------------------------------------------------------------------------- /train/compute/python/workloads/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/param/365dc0409e6f66a1823199a62553a1ed558693b6/train/compute/python/workloads/pytorch/__init__.py -------------------------------------------------------------------------------- /train/compute/python/workloads/pytorch/alex_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ...lib.operator import register_operator 5 | from ...lib.pytorch.operator_impl import BuildableOp 6 | 7 | 8 | class AlexNet(nn.Module): 9 | """ 10 | Ref: https://pytorch.org/vision/master/_modules/torchvision/models/alexnet.html 11 | """ 12 | 13 | def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None: 14 | super().__init__() 15 | self.features = nn.Sequential( 16 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 17 | nn.ReLU(inplace=True), 18 | nn.MaxPool2d(kernel_size=3, stride=2), 19 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 20 | nn.ReLU(inplace=True), 21 | nn.MaxPool2d(kernel_size=3, stride=2), 22 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 
27 |             nn.ReLU(inplace=True),
28 |             nn.MaxPool2d(kernel_size=3, stride=2),
29 |         )
30 |         self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
31 |         self.classifier = nn.Sequential(
32 |             nn.Dropout(p=dropout),
33 |             nn.Linear(256 * 6 * 6, 4096),
34 |             nn.ReLU(inplace=True),
35 |             nn.Dropout(p=dropout),
36 |             nn.Linear(4096, 4096),
37 |             nn.ReLU(inplace=True),
38 |             nn.Linear(4096, num_classes),
39 |         )
40 | 
41 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
42 |         x = self.features(x)
43 |         x = self.avgpool(x)
44 |         x = torch.flatten(x, 1)
45 |         x = self.classifier(x)
46 |         return x
47 | 
48 | 
49 | register_operator("pytorch.model.alex_net", BuildableOp(AlexNet))
50 | 
--------------------------------------------------------------------------------
/train/compute/python/workloads/pytorch/native_basic_ops.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | 
3 | import torch
4 | 
5 | from ...lib.operator import OperatorInterface, register_operators
6 | from ...lib.pytorch.operator_impl import BuildableOp, CallableOp, UnaryOp
7 | 
8 | 
9 | # Unary
10 | unary_ops: dict[str, OperatorInterface] = {
11 |     "torch.add_": UnaryOp("add_"),
12 |     "torch.clamp_": UnaryOp("clamp_"),
13 | }
14 | register_operators(unary_ops)
15 | 
16 | callable_ops: dict[str, OperatorInterface] = {
17 |     "torch.add": CallableOp(torch.add),
18 |     "torch.baddbmm": CallableOp(torch.baddbmm),
19 |     "torch.bmm": CallableOp(torch.bmm),
20 |     "torch.cat": CallableOp(torch.cat),
21 |     "torch.matmul": CallableOp(torch.matmul),
22 |     "torch.mean": CallableOp(torch.mean),
23 |     "torch.mm": CallableOp(torch.mm),
24 |     "torch.mul": CallableOp(torch.mul),
25 |     "torch.nn.functional.relu": CallableOp(torch.nn.functional.relu),
26 |     "torch.reshape": CallableOp(torch.reshape),
27 | }
28 | register_operators(callable_ops)
29 | 
30 | 
31 | buildable_ops: dict[str, OperatorInterface] = {
32 |     "torch.nn.AdaptiveAvgPool2d": BuildableOp(torch.nn.AdaptiveAvgPool2d),
33 |     "torch.nn.Conv2d": BuildableOp(torch.nn.Conv2d),
34 |     "torch.nn.Dropout": BuildableOp(torch.nn.Dropout),
35 |     "torch.nn.MaxPool2d": BuildableOp(torch.nn.MaxPool2d),
36 |     "torch.nn.ReLU": BuildableOp(torch.nn.ReLU),
37 |     "torch.nn.Linear": BuildableOp(torch.nn.Linear),
38 | }
39 | register_operators(buildable_ops)
40 | 
--------------------------------------------------------------------------------
/train/workloads/README.md:
--------------------------------------------------------------------------------
1 | # Distributed DLRM
2 | 
3 | The implementation is based on the [DLRM dist_exp branch](https://github.com/facebookresearch/dlrm/tree/dist_exp)
4 | and adds Facebook-specific features and optimizations.
5 | 
6 | Currently you need to download [the following PR](https://github.com/facebookresearch/dlrm/pull/127)
7 | to get the latest updates. (This will be fixed soon.)
8 | 
9 | ## Usage
10 | 
11 | Currently, the benchmark is launched with mpirun across multiple nodes. A hostfile needs to be
12 | created, or a host list should be given (see the sketch below). The DLRM parameters are passed
13 | in the same way as for the single-node master branch.
14 | ```bash
15 | mpirun -np 128 -hostfile hostfile python dlrm_s_pytorch.py ...
16 | ```
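17 | 
18 | With Open MPI, for example, a hostfile is a plain text file with one host per line and an
19 | optional slot count. The host names below are placeholders for illustration:
20 | 
21 | ```
22 | node001 slots=8
23 | node002 slots=8
24 | ```
25 | 
26 | Alternatively, the hosts can be listed directly on the command line, e.g.
27 | `mpirun -np 16 -host node001:8,node002:8 python dlrm_s_pytorch.py ...`.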
28 | 
29 | ## Example
30 | 
31 | The first two lines below build a dash-separated embedding-size string ("14000-14000-...",
32 | i.e. 64 embedding tables with 14000 rows each) that is passed to `--arch-embedding-size`:
33 | 
34 | ```bash
35 | large_arch_emb=$(printf '14000%.0s' {1..64})
36 | large_arch_emb=${large_arch_emb//"01"/"0-1"}
37 | 
38 | python dlrm_s_pytorch.py \
39 |     --arch-sparse-feature-size=128 \
40 |     --arch-mlp-bot="2000-1024-1024-128" \
41 |     --arch-mlp-top="4096-4096-4096-1" \
42 |     --arch-embedding-size=$large_arch_emb \
43 |     --data-generation=random \
44 |     --loss-function=bce \
45 |     --round-targets=True \
46 |     --learning-rate=0.1 \
47 |     --mini-batch-size=2048 \
48 |     --print-freq=10240 \
49 |     --print-time \
50 |     --test-mini-batch-size=16384 \
51 |     --test-num-workers=16 \
52 |     --num-indices-per-lookup-fixed=1 \
53 |     --num-indices-per-lookup=100 \
54 |     --arch-projection-size 30 \
55 |     --use-gpu
56 | ```
57 | 
58 | Please check the README.md in the PR for more details.
59 | 
--------------------------------------------------------------------------------