├── .gitignore ├── CODE_OF_CONDUCT.md ├── INPUT_GUIDE.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── commands.sh ├── fig └── taccl_workflow_main.png ├── setup.py └── taccl ├── __init__.py ├── __main__.py ├── algorithm.py ├── cli ├── __init__.py ├── common.py ├── known_collectives.py ├── known_topologies.py ├── ncclize.py └── solve.py ├── collectives.py ├── examples ├── sketch │ ├── sk-dgx2-n1.json │ ├── sk-ndv2-n1-cUp6.json │ ├── sk1-dgx2-n1.json │ ├── sk1-dgx2-n2.json │ ├── sk1-ndv2-n1.json │ ├── sk1-ndv2-n2.json │ ├── sk1-ndv2-n4.json │ ├── sk2-dgx2-n2.json │ ├── sk2-ndv2-n2.json │ ├── sk2-ndv2-n4.json │ └── sk3-ndv2-n2.json └── topo │ ├── topo-dgx2-1KB.json │ ├── topo-dgx2-1MB.json │ ├── topo-ndv2-1KB.json │ ├── topo-ndv2-1MB.json │ └── topo-ndv2-32KB.json ├── heuristic_ordering.py ├── instance.py ├── ncclize.py ├── reduce_scheduler.py ├── routing.py ├── scheduler.py ├── serialization.py ├── shortest_path_sets.py ├── topologies ├── __init__.py ├── generic.py ├── route_sketch.py └── topology.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.ilp 3 | __pycache__ 4 | *.egg-info -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /INPUT_GUIDE.md: -------------------------------------------------------------------------------- 1 | # Writing your own topology profiles 2 | ## How to provide profiling information for a topology 3 | ### Known topologies 4 | For any "\<topology\>" other than "custom", the link connection matrix is already defined in `./taccl/topologies/generic.py` and the topo-file only needs to have the node-config and link-profile information. 5 | 6 | Node-config information includes the following: 7 | - `name`: the id you want to give to the algorithm 8 | - `gpus_per_node`: number of GPUs in one node 9 | - `nics_per_node`: number of NICs in a node 10 | 11 | The link profile includes the following: 12 | - `alpha`: alpha-cost of the intra-node links 13 | - `node_betas_list`: list of beta-costs of the intra-node links in increasing order (we use a list because there can be multiple types of links within a node too, like in NVIDIA DGX-1 nodes) 14 | - `node_invbws_list`: list of total costs (alpha + beta) of the intra-node links 15 | - `remote_alpha`: alpha-cost of an inter-node link 16 | - `remote_beta`: beta-cost of an inter-node link 17 | - `remote_invbw`: remote_alpha + remote_beta 18 | 19 | Guidelines for providing link profiles: 20 | 1. Beta values are obtained by multiplying a beta (in us/MB) by the input size (MB) for which you are trying to generate an algorithm. 21 | 2. Please ensure that all link profile costs have a big-enough integral part, since TACCL's ILP encoding rounds the costs down to integers in some stages and we do not want to lose cost information.
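Putting the node-config and link-profile fields together, the sketch below builds a hypothetical topo-file following the two guidelines above: betas (us/MB) are multiplied by the target input size, and then all costs are scaled by a small factor so that rounding down to integers loses no information. The file name, node shape, and all cost values are illustrative assumptions, not profiled numbers.

```python
import json

# Hypothetical, un-profiled costs, purely for illustration.
alpha = 0.3               # alpha-cost of intra-node links (us)
beta_per_mb = 0.5         # beta-cost of intra-node links (us/MB)
remote_alpha = 1.5        # alpha-cost of an inter-node link (us)
remote_beta_per_mb = 2.0  # beta-cost of an inter-node link (us/MB)
size_mb = 1.0             # input size the algorithm is generated for

# Guideline 1: beta values = beta (us/MB) * input size (MB).
node_beta = beta_per_mb * size_mb
remote_beta = remote_beta_per_mb * size_mb

# Guideline 2: scale all costs so their integral parts stay informative
# after TACCL's ILP encoding rounds them down to integers.
scale = 10

profile = {
    "name": "example-node",   # hypothetical id for the algorithm
    "gpus_per_node": 8,
    "nics_per_node": 1,
    "alpha": alpha * scale,
    "node_betas_list": [node_beta * scale],
    "node_invbws_list": [(alpha + node_beta) * scale],
    "remote_alpha": remote_alpha * scale,
    "remote_beta": remote_beta * scale,
    "remote_invbw": (remote_alpha + remote_beta) * scale,
}

with open("topo-example.json", "w") as f:
    json.dump(profile, f, indent=2)
```

Scaling every cost by the same factor leaves their ratios unchanged, so the synthesis problem is the same while no intra-node cost rounds down to 0.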
For example, if your profile tuple is (alpha=0.3, node_betas_list=[0.5], node_invbws_list=[0.8], remote_alpha=1.5, remote_beta=2, remote_invbw=3.5), then you should multiply all values by some small factor (like 10), so that intra-node link costs don't all become 0. This will not change the synthesis problem. 22 | 23 | ### Custom topologies 24 | In case the node topology is different from the topologies provided in the KnownTopologies class, you can set "\<topology\>" to "custom" and provide "links", "betas", and "invbws" matrices instead of the lists of values "node_betas_list" and "node_invbws_list" in the topology-file input. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TACCL: Guiding Collective Algorithm Synthesis using Communication Sketches 2 | 3 |
4 | ![TACCL workflow](fig/taccl_workflow_main.png) 5 |
6 | 7 | 8 | > **TACCL: Guiding Collective Algorithm Synthesis using Communication Sketches**
9 | > Aashaka Shah, Vijay Chidambaram, Meghan Cowan, Saeed Maleki, Madan Musuvathi, Todd Mytkowicz, Jacob Nelson, Olli Saarikivi, Rachee Singh
10 | > **NSDI 2023** [[paper](https://arxiv.org/pdf/2111.04867.pdf)] 11 | 12 | TACCL is a tool to generate algorithms for collective communication like AllGather, AllToAll, and AllReduce for any given hardware configuration. TACCL takes a human-in-the-loop approach to algorithm design in which a user provides _communication sketches_ to guide the synthesis process. TACCL outputs TACCL-EF, an execution format that contains the schedule of GPU data transfers to efficiently implement the target collective. TACCL's schedules can be run by registering TACCL-EF files using the [MSCCL](https://github.com/microsoft/msccl-tools) tool stack. 13 | 14 | 15 | ## Installation 16 | We use Gurobi to solve the optimization problem. Please obtain a [Gurobi license](https://www.gurobi.com/downloads/) online and then proceed to install the Gurobi licensing tools as follows. 17 | Within an anaconda environment, run: 18 | ``` 19 | conda config --add channels http://conda.anaconda.org/gurobi 20 | conda install -c conda-forge gurobi -y 21 | 22 | ``` 23 | Finally, run 24 | ``` 25 | pip install . 26 | ``` 27 | 28 | 29 | ## Usage 30 | 31 | ### Generating the algorithm 32 | To generate a data transfer algorithm for a collective on a specific topology, please provide a topology file containing profiling information for the links in the node and a sketch file specifying the communication sketch.
Using these, a JSON file of data transfer steps for the collective algorithm can be generated with the following command: 33 | 34 | ``` 35 | $ taccl solve <topology> <collective> --topology-file <topo-file> --sketch-file <sketch-file> -o <output-file> 36 | ``` 37 | 38 | For example, to generate an algorithm for Allgather used in the Evaluation for _dgx2-sk-1_, we run the following command: 39 | ``` 40 | $ cd taccl/examples 41 | $ taccl solve DGX2 Allgather --topology-file ../taccl/examples/topo/topo-dgx2-1MB.json --sketch-file ../taccl/examples/sketch/sk1-dgx2-n2.json 42 | ``` 43 | 44 | To generate schedules for the combining collective AllReduce, we first obtain an AllGather algorithm and use it to generate the AllReduce algorithm. \<ts\> is the timestamp that is used to save the send_dict of the AllGather algorithm and can be obtained from the suffix of the JSON algorithm file. 45 | ``` 46 | $ cd taccl/examples 47 | $ taccl solve DGX2 Allgather --topology-file ../taccl/examples/topo/topo-dgx2-1MB.json --sketch-file ../taccl/examples/sketch/sk1-dgx2-n2.json 48 | $ taccl combine DGX2 Allgather --topology-file ../taccl/examples/topo/topo-dgx2-1MB.json --sketch-file ../taccl/examples/sketch/sk1-dgx2-n2.json --ts <ts> 49 | ``` 50 | 51 | `./commands.sh` shows the commands with sketches and profiled topologies that can be used to obtain the results in our paper. 52 | 53 | 54 | #### Providing topology 55 | - "\<topology\>" should be selected from the 'known_topologies' constructor and can be one of "custom", "HubAndSpoke", "DGX2", or "NDv2". `./taccl/topologies/generic.py` defines the link connection matrix within a node for each \<topology\>. `links[dst][src]` is 1 if there is one link going from `src` to `dst`. 56 | - "\<topo-file\>" contains profiling information about the node, like latency and bandwidth costs of intra-node and inter-node transfers as well as the number of GPUs and NICs in the node.
Depending on the profiling information given in the topology-file provided by the user, the profile attributes `alpha`, `betas`, `invbws`, `nics_per_node`, `remote_invbw`, `remote_alpha`, and `remote_beta` are used to obtain a NodeTopology object. The `taccl/examples/topo` directory gives examples of topology files for DGX-2 and NDv2 nodes that can be provided to `--topology-file`. 57 | 58 | You can add new node topologies as a separate function in `./taccl/cli/known_topologies.py` or use the "custom" option and provide all details of the topology in the user-provided topology-file. 59 | 60 | #### Providing communication sketch 61 | A communication sketch serves the following three purposes: 62 | 1. Create a logical topology that determines how different nodes are connected to each other 63 | 2. Annotate links which form a switch 64 | 3. Annotate symmetry planes in the topology 65 | 66 | The `./taccl/examples/sketch` directory gives examples of some communication sketches that can be used for NVIDIA DGX-2 and Azure NDv2 nodes. 67 | 68 | ### Lowering to TACCL-EF 69 | Once the algorithm is generated, we lower it into a TACCL-EF file. The number of instances \<instances\> determines the multiple we will use to increase the number of channels that are used to perform sends and receives. When there are already many threadblocks being used per GPU, you should use a single instance for the best performance. 70 | ``` 71 | $ taccl ncclize <algorithm-json> --instances <instances> 72 | ``` 73 | 74 | ### Running with MSCCL 75 | The TACCL-EF file can be used as input by the [MSCCL runtime](https://github.com/microsoft/msccl) to actually run the algorithm on the hardware. Please follow the setup instructions in the MSCCL repository to run TACCL algorithms.
Once MSCCL is set up, TACCL-EF files can be benchmarked using nccl-tests as follows: 76 | ``` 77 | $ mpirun -np <num-processes> -x LD_LIBRARY_PATH=msccl/build/lib/:$LD_LIBRARY_PATH -x NCCL_DEBUG=INFO -x NCCL_DEBUG_SUBSYS=INIT,ENV -x MSCCL_XML_FILES=<taccl-ef-file> -x NCCL_ALGO=MSCCL,RING,TREE nccl-tests/build/<collective>_perf -b 128 -e 32MB -f 2 -g 1 -c 1 -n 100 -w 100 -G 100 -z 0 78 | ``` 79 | TACCL-EF files can also be registered in the [MSCCL-toolkit](https://github.com/microsoft/msccl-tools) to be used in frameworks like PyTorch and TensorFlow. 80 | 81 | ## Guide to providing topology profiles 82 | [INPUT_GUIDE.md](./INPUT_GUIDE.md) provides more details on how to specify different profiles for topologies. 83 | 84 | 85 | ## Citation 86 | > Shah, A., Chidambaram, V., Cowan, M., Maleki, S., Musuvathi, M., Mytkowicz, T., Nelson, J. and Saarikivi, O., 2023. TACCL: Guiding Collective Algorithm Synthesis using Communication Sketches. In 20th USENIX Symposium on Networked Systems Design and Implementation (NSDI 23) (pp. 593-612). 87 | 88 | ## Contributing 89 | 90 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 91 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 92 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 93 | 94 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 95 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 96 | provided by the bot. You will only need to do this once across all repos using our CLA. 97 | 98 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 99 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 100 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
101 | 102 | ## Trademarks 103 | 104 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 105 | trademarks or logos is subject to and must follow 106 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 107 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 108 | Any use of third-party trademarks or logos are subject to those third-party's policies. -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). 
If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /commands.sh: -------------------------------------------------------------------------------- 1 | 2 | # [dgx2-allgather-sk1-n2]: 2-node NVIDIA DGX-2 AllGather for 1MB data chunks using sketch-1 3 | taccl solve DGX2 Allgather --topology-file ../taccl/examples/topo/topo-dgx2-1MB.json --sketch-file ../taccl/examples/sketch/sk1-dgx2-n2.json 4 | 5 | # [dgx2-allgather-sk2-n2]: 2-node NVIDIA DGX-2 AllGather for 1KB data chunks using sketch-2 6 | taccl solve DGX2 Allgather --topology-file ../taccl/examples/topo/topo-dgx2-1KB.json --sketch-file ../taccl/examples/sketch/sk2-dgx2-n2.json 7 | 8 | # [ndv2-allgather-sk1-n2]: 2-node Azure NDv2 AllGather for 1MB data chunks using sketch-1 9 | taccl solve NDv2 Allgather --topology-file ../taccl/examples/topo/topo-ndv2-1MB.json --sketch-file ../taccl/examples/sketch/sk1-ndv2-n2.json 10 | 11 | # [dgx2-alltoall-sk2-n2]: 2-node NVIDIA DGX-2 AlltoAll for 1KB data chunks using sketch-2 12 | taccl solve DGX2 Alltoall --topology-file ../taccl/examples/topo/topo-dgx2-1KB.json --sketch-file ../taccl/examples/sketch/sk2-dgx2-n2.json 13 | 14 | # dgx2-alltoall-sk3-n2 15 | taccl solve DGX2 Alltoall --topology-file ../taccl/examples/topo/topo-dgx2-1KB.json --sketch-file ../taccl/examples/sketch/sk3-dgx2-n2.json 16 | 17 | # ndv2-alltoall-sk1-n2 18 | taccl solve NDv2 Alltoall --topology-file ../taccl/examples/topo/topo-ndv2-1MB.json --sketch-file ../taccl/examples/sketch/sk1-ndv2-n2.json 19 | 20 | # ndv2-alltoall-sk2-n2 21 | taccl solve NDv2 Alltoall --topology-file ../taccl/examples/topo/topo-ndv2-1KB.json --sketch-file ../taccl/examples/sketch/sk2-ndv2-n2.json 22 | 23 | # dgx2-allreduce-sk1-n2 24 | taccl combine DGX2 Allgather --topology-file ../taccl/examples/topo/topo-dgx2-1MB.json --sketch-file ../taccl/examples/sketch/sk1-dgx2-n2.json --ts <ts> 25 | 26 | # dgx2-allreduce-sk2-n2 27 | taccl combine DGX2 Allgather --topology-file
../taccl/examples/topo/topo-dgx2-1KB.json --sketch-file ../taccl/examples/sketch/sk2-dgx2-n2.json --ts <ts> 28 | 29 | # ndv2-allreduce-sk1-n2 30 | taccl combine NDv2 Allgather --topology-file ../taccl/examples/topo/topo-ndv2-1MB.json --sketch-file ../taccl/examples/sketch/sk1-ndv2-n2.json --ts <ts> 31 | 32 | 33 | # ndv2-allgather-sk1-n4 34 | taccl combine NDv2 Allgather --topology-file ../taccl/examples/topo/topo-ndv2-1MB.json --sketch-file ../taccl/examples/sketch/sk1-ndv2-n4.json 35 | 36 | # ndv2-allgather-sk1-n6 37 | taccl solve NDv2 Allgather --topology-file ../taccl/examples/topo/topo-ndv2-1MB.json --sketch-file ../taccl/examples/sketch/sk1-ndv2-n6.json 38 | 39 | # ndv2-allgather-sk1-n8 40 | taccl solve NDv2 Allgather --topology-file ../taccl/examples/topo/topo-ndv2-1MB.json --sketch-file ../taccl/examples/sketch/sk1-ndv2-n8.json 41 | 42 | -------------------------------------------------------------------------------- /fig/taccl_workflow_main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/taccl/426bd1e30b7a170713f04f4cf20ace57e8c07251/fig/taccl_workflow_main.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License.
3 | 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name='taccl', 8 | version='2.0.0', 9 | packages=find_packages(), 10 | entry_points={ 11 | 'console_scripts': [ 12 | 'taccl = taccl.__main__:main', 13 | ], 14 | }, 15 | install_requires=[ 16 | 'dataclasses; python_version < "3.7"', 17 | 'z3-solver', 18 | 'argcomplete', 19 | 'lxml', 20 | 'gurobipy', 21 | 'numpy' 22 | ], 23 | python_requires='>=3.6', 24 | ) 25 | -------------------------------------------------------------------------------- /taccl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. -------------------------------------------------------------------------------- /taccl/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # PYTHON_ARGCOMPLETE_OK 3 | 4 | # Copyright (c) Microsoft Corporation. 5 | # Licensed under the MIT License. 6 | 7 | from taccl.cli import * 8 | 9 | import argparse 10 | import argcomplete 11 | import sys 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser('taccl') 15 | 16 | cmd_parsers = parser.add_subparsers(title='command', dest='command') 17 | cmd_parsers.required = True 18 | 19 | handlers = [] 20 | handlers.append(make_handle_solve_comm_sketch(cmd_parsers)) 21 | handlers.append(make_handle_combine_comm_sketch(cmd_parsers)) 22 | handlers.append(make_handle_ncclize(cmd_parsers)) 23 | 24 | argcomplete.autocomplete(parser) 25 | args = parser.parse_args() 26 | 27 | for handler in handlers: 28 | if handler(args, args.command): 29 | break 30 | 31 | if __name__ == '__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /taccl/algorithm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from dataclasses import dataclass 5 | from collections import defaultdict 6 | 7 | @dataclass 8 | class Step(object): 9 | rounds: int 10 | sends: list 11 | 12 | class Algorithm(object): 13 | def __init__(self, name, collective, topology, instance, steps, input_map = {}, output_map = {}, cont=False): 14 | self.name = name 15 | self.topology = topology 16 | self.collective = collective 17 | self.instance = instance 18 | self.steps = steps 19 | self.input_map = input_map 20 | self.output_map = output_map 21 | self.cont = cont 22 | 23 | self._update_link_utilizations() 24 | if not cont: 25 | self._check_bandwidth_constraints() 26 | else: 27 | self._check_real_bandwidth_constraints() 28 | 29 | @classmethod 30 | def make_implementation(cls, collective, topology, instance, steps, cont=False, suffix=""): 31 | # Figure out input and output addresses 32 | input_map = {} 33 | output_map = {} 34 | for rank in collective.ranks(): 35 | input_addrs = set() 36 | output_addrs = set() 37 | for chunk in collective.chunks(): 38 | # An address is an input address if any of its chunks is in the precondition 39 | if collective.precondition(rank, chunk): 40 | input_addrs.add(collective.address(chunk)) 41 | # An address is an output address if any of its chunks is in the postcondition 42 | if collective.postcondition(rank, chunk): 43 | output_addrs.add(collective.address(chunk)) 44 | if len(input_addrs) > 0: 45 | input_map[rank] = input_addrs 46 | if len(output_addrs) > 0: 47 | output_map[rank] = output_addrs 48 | 49 | # Concatenate collective and topology names plus instance arguments to create a name 50 | name = f'{collective.name}-{topology.name}-{instance}{suffix}' 51 | 52 | algo = cls(name, collective, topology, instance, steps, input_map, output_map, cont) 53 | algo.check_implements(collective) 54 | if instance.extra_rounds > 0: 55 | used_extra_rounds = algo.extra_rounds() 56 | if used_extra_rounds > instance.extra_rounds: 57 | raise ValueError(f'steps use {used_extra_rounds} 
extra rounds but only {instance.extra_rounds} were allowed') 58 | return algo 59 | 60 | def ranks(self): 61 | return range(self.topology.num_nodes()) 62 | 63 | def num_steps(self): 64 | return len(self.steps) 65 | 66 | def extra_rounds(self): 67 | rounds = 0 68 | for step in self.steps: 69 | rounds += step.rounds 70 | return rounds - self.num_steps() 71 | 72 | def is_pipelined(self): 73 | return self.instance.pipeline != None 74 | 75 | def check_implements(self, collective): 76 | if self.topology.num_nodes() != collective.num_nodes: 77 | raise RuntimeError('topology and collective have different number of nodes') 78 | # Find which chunks will be sent from an address 79 | chunks_at_address = defaultdict(list) 80 | for chunk in collective.chunks(): 81 | chunks_at_address[collective.address(chunk)].append(chunk) 82 | # State records if a rank holds a chunk 83 | def idx(rank, chunk): 84 | return rank * collective.num_chunks + chunk 85 | state = [False] * (collective.num_nodes * collective.num_chunks) 86 | # Initialize state from precondition 87 | for rank in collective.ranks(): 88 | for chunk in collective.chunks(): 89 | state[idx(rank, chunk)] = collective.precondition(rank, chunk) 90 | # Propagate state through sends of every step 91 | for step in self.steps: 92 | next_state = state.copy() 93 | if len(step.sends[0]) == 5: 94 | for addr, src, dst, _, _ in step.sends: 95 | for chunk in chunks_at_address[addr]: 96 | next_state[idx(dst, chunk)] |= state[idx(src, chunk)] 97 | elif len(step.sends[0]) == 6: 98 | for addr, src, dst, _, _, _ in step.sends: 99 | for chunk in chunks_at_address[addr]: 100 | next_state[idx(dst, chunk)] |= state[idx(src, chunk)] 101 | else: 102 | for addr, src, dst in step.sends: 103 | for chunk in chunks_at_address[addr]: 104 | next_state[idx(dst, chunk)] |= state[idx(src, chunk)] 105 | state = next_state 106 | # Check that the postcondition holds 107 | for rank in collective.ranks(): 108 | for chunk in collective.chunks(): 109 | # print(rank, 
chunk, state[idx(rank, chunk)]) 110 | if collective.postcondition(rank, chunk) and not state[idx(rank, chunk)]: 111 | raise RuntimeError(f'rank {rank} does not get chunk {chunk} as required by the postcondition') 112 | 113 | def _update_link_utilizations(self): 114 | self._link_utilizations = [] 115 | ranks = range(self.topology.num_nodes()) 116 | for step in self.steps: 117 | step_utilizations = [[0 for _ in ranks] for _ in ranks] 118 | if len(step.sends[0]) == 5: 119 | for addr, src, dst, _, _ in step.sends: 120 | step_utilizations[dst][src] += 1 # Same order as topology 121 | elif len(step.sends[0]) == 6: 122 | for addr, src, dst, _, _, _ in step.sends: 123 | step_utilizations[dst][src] += 1 # Same order as topology 124 | else: 125 | for addr, src, dst in step.sends: 126 | step_utilizations[dst][src] += 1 # Same order as topology 127 | self._link_utilizations.append(step_utilizations) 128 | 129 | def _check_bandwidth_constraints(self): 130 | for srcs, dsts, bw, name in self.topology.bandwidth_constraints(): 131 | for step_num, step in enumerate(self.steps): 132 | util = 0 133 | for dst in dsts: 134 | for src in srcs: 135 | if self.is_pipelined(): 136 | for overlapping_step in range(step_num, len(self.steps), self.instance.pipeline): 137 | util += self._link_utilizations[overlapping_step][dst][src] 138 | else: 139 | util += self._link_utilizations[step_num][dst][src] 140 | assert util <= bw * step.rounds, \ 141 | f'Step {step_num} uses {util} bandwidth but constraint {name} only allows for {bw * step.rounds} bandwidth (when rounds={step.rounds}).' 
142 | 143 | def _check_real_bandwidth_constraints(self): 144 | for srcs, dsts, bw, l, name in self.topology.real_bandwidth_constraints(): 145 | for step_num, step in enumerate(self.steps): 146 | util = 0 147 | for dst in dsts: 148 | for src in srcs: 149 | if self.is_pipelined(): 150 | for overlapping_step in range(step_num, len(self.steps), self.instance.pipeline): 151 | util += self._link_utilizations[overlapping_step][dst][src] 152 | else: 153 | util += self._link_utilizations[step_num][dst][src] 154 | assert util * bw <= step.rounds, \ 155 | f'Step {step_num} uses {util * bw} time but constraint {name} only allows for {step.rounds} time (when rounds={step.rounds}).' 156 | 157 | 158 | def __str__(self): 159 | s = '' 160 | for i, step in enumerate(self.steps): 161 | if i != 0: 162 | s += '\n' 163 | if step.rounds > 1: 164 | s += f'(step {i+1}, rounds={step.rounds}) ' 165 | else: 166 | s += f'(step {i+1}) ' 167 | if len(step.sends[0]) == 5: 168 | s += ', '.join([f'{chunk}:{src}→{dst}' for chunk, src, dst, _, _ in step.sends]) 169 | elif len(step.sends[0]) == 6: 170 | s += ', '.join([f'{chunk}:{src}→{dst}' for chunk, src, dst, _, _, _ in step.sends]) 171 | else: 172 | s += ', '.join([f'{chunk}:{src}→{dst}' for chunk, src, dst in step.sends]) 173 | return s 174 | -------------------------------------------------------------------------------- /taccl/cli/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .solve import * 5 | from .ncclize import * 6 | -------------------------------------------------------------------------------- /taccl/cli/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from taccl.serialization import * 5 | from taccl.instance import * 6 | from taccl.topologies import TACCLTopology, IntraNode, IntraNode_Switch, InterNode, InterNode_Relay, MultiNode, Symmetry, HyperParameter, RouteSketch, NodeTopology 7 | 8 | import json 9 | from pathlib import Path 10 | import sys 11 | import re 12 | from fractions import Fraction 13 | from collections import defaultdict 14 | 15 | def _legalize_sccl_name(name): 16 | name = name.replace('(', '.') 17 | name = name.replace('=', '') 18 | name = name.replace(',', '.') 19 | name = name.replace(')', '') 20 | return name 21 | 22 | def name_sccl_object(name, ending='sccl.json'): 23 | return f'{_legalize_sccl_name(name)}.{ending}' 24 | 25 | def _validate_output_directory(directory): 26 | if not directory.exists(): 27 | print('error: output directory does not exists', file=sys.stderr) 28 | exit(1) 29 | if not directory.is_dir(): 30 | print('error: output path is not a directory', file=sys.stderr) 31 | exit(1) 32 | 33 | def _handle_write_to_directory(directory, force, get_contents, preferred_file_name): 34 | output_file = directory / preferred_file_name 35 | if output_file.exists(): 36 | if output_file.is_dir(): 37 | print(f'error: output path is a directory', file=sys.stderr) 38 | exit(1) 39 | if force: 40 | print(f'Overwriting {output_file}') 41 | else: 42 | print(f'file already exists, use -f/--force to overwrite {output_file}', file=sys.stderr) 43 | return False 44 | with output_file.open('w') as f: 45 | f.write(get_contents()) 46 | print(f'Wrote to {output_file}') 47 | return True 48 | 49 | def add_output_file(parser): 50 | group = parser.add_mutually_exclusive_group() 51 | group.add_argument('-o', '--output', type=Path, help='file to write synthesized algorithm to', metavar='FILE') 52 | group.add_argument('-d', '--directory', type=Path, default=Path(), help='directory to write the synthesized algorithm to', metavar='DIR') 53 | parser.add_argument('-f', '--force', action='store_true', 
help='overwrite existing files') 54 | parser.add_argument('--no-save', action='store_true', help='do not save to file') 55 | 56 | def validate_args(args): 57 | if args.output != None: 58 | if args.output.is_dir(): 59 | print(f'error: output path is a directory, did you mean to use -d?', file=sys.stderr) 60 | exit(1) 61 | if args.directory != None: 62 | _validate_output_directory(args.directory) 63 | 64 | def handle(args, get_contents, preferred_file_name): 65 | if args.no_save: 66 | return False 67 | if args.output != None: 68 | if args.output.exists() and not args.force: 69 | print(f'file already exists, use -f/--force to overwrite {args.output}', file=sys.stderr) 70 | return False 71 | with args.output.open('w') as f: 72 | f.write(get_contents()) 73 | print(f'Wrote to {args.output}') 74 | else: 75 | return _handle_write_to_directory(args.directory, args.force, get_contents, preferred_file_name) 76 | return True 77 | 78 | return validate_args, handle 79 | 80 | def add_output_algorithm(parser): 81 | validate_args, handle_file = add_output_file(parser) 82 | 83 | def handle(args, algorithm): 84 | if algorithm == None: 85 | return # Strategies/distributors have their specific failure prints 86 | 87 | handled = handle_file(args, lambda: SCCLEncoder().encode(algorithm), name_sccl_object(algorithm.name)) 88 | if not handled: 89 | print(f'\n{algorithm.name} algorithm:') 90 | print(algorithm) 91 | 92 | return validate_args, handle 93 | 94 | def add_output_topology(parser): 95 | validate_args, handle_file = add_output_file(parser) 96 | 97 | def handle(args, topology): 98 | handled = handle_file(args, lambda: SCCLEncoder().encode(topology), name_sccl_object(topology.name)) 99 | 100 | return validate_args, handle 101 | 102 | def add_output_sccl_objects(parser): 103 | parser.add_argument('-d', '--directory', type=Path, default=Path(), help='directory to write outputs to', metavar='DIR') 104 | parser.add_argument('-f', '--force', action='store_true', help='overwrite existing 
files') 105 | parser.add_argument('--no-save', action='store_true', help='do not save to file') 106 | 107 | def validate_args(args): 108 | _validate_output_directory(args.directory) 109 | 110 | def handle(args, sccl_object, name): 111 | if not args.no_save: 112 | _handle_write_to_directory(args.directory, args.force, lambda: SCCLEncoder().encode(sccl_object), name_sccl_object(name)) 113 | 114 | return validate_args, handle 115 | 116 | def add_input_algorithm(parser, multiple=False, name='algorithm'): 117 | parser.add_argument(name, type=Path, nargs='+' if multiple else 1, help=f'algorithm to operate on') 118 | 119 | def read_algorithm(args): 120 | algos = [] 121 | for input_file in vars(args)[name]: 122 | if not input_file.exists(): 123 | print(f'error: input file not found: {input_file}', file=sys.stderr) 124 | exit(1) 125 | 126 | algo = load_sccl_object(input_file) 127 | algos.append(algo) 128 | if multiple: 129 | return algos 130 | else: 131 | return algos[0] 132 | 133 | return read_algorithm 134 | 135 | def add_instance(parser, take_steps=True, take_rounds=True, take_chunks=True): 136 | if take_steps: 137 | parser.add_argument('-s', '--steps', type=int, required=True) 138 | if take_rounds: 139 | parser.add_argument('-r', '--rounds', type=int, default=None, metavar='N') 140 | if take_chunks: 141 | parser.add_argument('-c', '--chunks', type=int, default=1, metavar='N') 142 | parser.add_argument('--pipeline', type=int, default=None, metavar='N') 143 | parser.add_argument('--extra-memory', type=int, default=None, metavar='N') 144 | parser.add_argument('--allow-exchange', action='store_true') 145 | 146 | def handle(args): 147 | if take_rounds: 148 | if args.rounds != None: 149 | if args.rounds < args.steps: 150 | parser.error(f'error: rounds cannot be less than steps ({args.rounds} < {args.steps})') 151 | extra_rounds = args.rounds - args.steps 152 | else: 153 | extra_rounds = 0 154 | return Instance( 155 | steps=args.steps if take_steps else None, 156 | 
extra_rounds=extra_rounds if take_rounds else 0, 157 | chunks=args.chunks if take_chunks else 1, 158 | pipeline=args.pipeline, 159 | extra_memory=args.extra_memory, 160 | allow_exchange=args.allow_exchange) 161 | 162 | return handle 163 | 164 | def parse_fraction(value): 165 | try: 166 | return int(value) 167 | except ValueError: 168 | m = re.fullmatch('(.+)/(.+)', value) 169 | if m == None: 170 | raise ValueError('value must be in format "<numerator>/<denominator>"') 171 | numerator = int(m.group(1)) 172 | denominator = int(m.group(2)) 173 | return Fraction(numerator, denominator) 174 | 175 | def make_cmd_category(cmd_parsers, name, title, handler_funcs): 176 | cmd = cmd_parsers.add_parser(name) 177 | category_parsers = cmd.add_subparsers(title=title, dest=title) 178 | category_parsers.required = True 179 | 180 | handlers = [] 181 | for func in handler_funcs: 182 | handlers.append(func(category_parsers)) 183 | 184 | def handle(args, command): 185 | if command != name: 186 | return False 187 | 188 | for handler in handlers: 189 | if handler(args, vars(args)[title]): 190 | return True 191 | 192 | return handle 193 | 194 | def _multiply_link_matrix(links, factor_matrix): 195 | new_links = [[links[dst][src] * factor_matrix[dst][src] for src in range(len(links))] for dst in range(len(links[0]))] 196 | return new_links 197 | 198 | def _div_beta_add_alpha(alpha, betas, factor_matrix): 199 | new_betas = [[ int(betas[dst][src] / factor_matrix[dst][src]) for src in range(len(betas))] for dst in range(len(betas[0]))] 200 | new_invbws = [[ int(alpha + betas[dst][src] / factor_matrix[dst][src]) for src in range(len(betas))] for dst in range(len(betas[0]))] 201 | return new_betas, new_invbws 202 | 203 | def _filter_links(links, conn): 204 | new_links = [ 205 | [ 206 | links[dst][src] 207 | if src in conn and dst in conn[src] 208 | else 0 209 | for src in range(len(links)) 210 | ] for dst in range(len(links[0])) 211 | ] 212 | return new_links 213 | 214 | def _filter_invbws(invbws, conn): 215 | 
new_invbws = [ 216 | [ 217 | invbws[dst][src] * len(conn[src]) 218 | if src in conn and dst in conn[src] 219 | else 0 220 | for src in range(len(invbws)) 221 | ] for dst in range(len(invbws[0])) 222 | ] 223 | return new_invbws 224 | 225 | 226 | 227 | def parse_and_get_topo(node_topology: NodeTopology, comm_sketch_file, reduce=False): 228 | cs_json = json.load(comm_sketch_file) 229 | copies = cs_json["nnodes"] 230 | ngpus_per_node = len(node_topology.links) 231 | switches = [] 232 | 233 | if cs_json["intranode_sketch"]["strategy"] == "switch": 234 | assert len(cs_json["intranode_sketch"]["switches"]) == len(cs_json["intranode_sketch"]["switch_hyperedge_strategy"]) 235 | intranode_sketch = IntraNode_Switch( 236 | cs_json["intranode_sketch"]["strategy"], 237 | cs_json["intranode_sketch"]["switches"], 238 | cs_json["intranode_sketch"]["switch_hyperedge_strategy"] 239 | ) 240 | switches = cs_json["intranode_sketch"]["switches"] 241 | 242 | intra_node_split = [[1 for _ in range(ngpus_per_node)] for _ in range(ngpus_per_node)] 243 | # Assert that the intersection of any two sets of switches is either empty or equal to the set of switches 244 | # This is required for correctly deriving the way to split the bandwidth in the node and update the links 245 | # Get the number of disjoint sets of switches and update the intra_node_split matrix 246 | added = [] 247 | intersections = {} 248 | for i in range(len(switches)): 249 | if (i not in added): 250 | num_same = 1 251 | for j in range(i+1, len(switches)): 252 | intersection = list(set(switches[i]) & set(switches[j])) 253 | assert len(intersection) == 0 or (len(intersection) == len(switches[i]) and len(intersection) == len(switches[j])) 254 | if len(intersection): 255 | num_same += 1 256 | added.append(j) 257 | for gpu_i in switches[i]: 258 | for gpu_j in switches[i]: 259 | if (gpu_i != gpu_j): 260 | intra_node_split[gpu_i][gpu_j] = num_same 261 | intra_node_split[gpu_j][gpu_i] = num_same 262 | added.append(i) 263 | for row in 
intra_node_split: 264 | print(row) 265 | new_switches = [[n for n in switches[i]] for i in added] 266 | switches = new_switches 267 | # Update the links and invbws 268 | new_links = _multiply_link_matrix(node_topology.links, intra_node_split) 269 | new_betas, new_invbws = _div_beta_add_alpha(node_topology.alpha, node_topology.betas, intra_node_split) 270 | node_topology.links = new_links 271 | node_topology.betas = new_betas 272 | node_topology.invbws = new_invbws 273 | elif cs_json["intranode_sketch"]["strategy"] == "maxmin" or cs_json["intranode_sketch"]["strategy"] == "minmax": 274 | intranode_sketch = IntraNode(cs_json["intranode_sketch"]["strategy"]) 275 | elif cs_json["intranode_sketch"]["strategy"] == "none": 276 | intranode_sketch = IntraNode("none") 277 | else: 278 | assert False, "No such intranode strategy available" 279 | 280 | if copies > 1: 281 | nics_per_node = node_topology.nics_per_node 282 | assert "internode_sketch" in cs_json 283 | assert cs_json["internode_sketch"]["strategy"] == "relay" 284 | assert "internode_conn" in cs_json["internode_sketch"] 285 | internode_conn = cs_json["internode_sketch"]["internode_conn"] 286 | if not isinstance(internode_conn, dict): 287 | assert isinstance(internode_conn, str) 288 | conns = defaultdict(list) 289 | if internode_conn == "fully-connected": 290 | for i in range(ngpus_per_node): 291 | for j in range(ngpus_per_node): 292 | conns[i].append(j) 293 | elif internode_conn == "direct-map": 294 | for i in range(ngpus_per_node): 295 | conns[i].append(i) 296 | else: 297 | assert False, "No such internode connection strategy" 298 | internode_conn = conns 299 | 300 | num_senders = len(internode_conn) 301 | # Number of outgoing connections is restricted to be the same for all sender GPUs 302 | num_dsts = -1 303 | for (src, dsts) in internode_conn.items(): 304 | if num_dsts == -1: 305 | num_dsts = len(dsts) 306 | else: 307 | assert num_dsts == len(dsts) 308 | total_outgoing_links = num_dsts * num_senders 309 | 
beta_split_factor = total_outgoing_links / nics_per_node 310 | node_topology.remote_beta = int(node_topology.remote_beta * beta_split_factor) 311 | node_topology.remote_invbw = int(node_topology.remote_alpha + node_topology.remote_beta) 312 | 313 | gpus_to_sender_rev_map = cs_json["internode_sketch"]["gpus_to_sender_rev_map"] if "gpus_to_sender_rev_map" in cs_json["internode_sketch"] else None 314 | enforce_ordering = cs_json["internode_sketch"]["enforce_ordering"] if "enforce_ordering" in cs_json["internode_sketch"] else False 315 | internode_sketch = InterNode_Relay( 316 | internode_conn, 317 | gpus_to_sender_rev_map, 318 | enforce_ordering, 319 | ) 320 | else: 321 | internode_conn = None 322 | internode_sketch = None 323 | 324 | multinode_sketch = MultiNode(["round-robin"], [1], [copies]) 325 | 326 | symmetry = Symmetry(cs_json["symmetry_offsets"]) 327 | 328 | if reduce: 329 | scheduling_heuristic = 12 330 | elif len(switches): 331 | scheduling_heuristic = 10 332 | elif copies > 1: 333 | scheduling_heuristic = 14 334 | else: 335 | scheduling_heuristic = 5 336 | 337 | hyperparameters = HyperParameter( 338 | cs_json["hyperparameters"]["input_chunkup"], 339 | scheduling_heuristic 340 | ) 341 | 342 | route_sketch = RouteSketch( 343 | intranode_sketch, 344 | internode_sketch, 345 | multinode_sketch, 346 | symmetry, 347 | hyperparameters 348 | ) 349 | # return route_sketch 350 | 351 | topology = TACCLTopology( 352 | name=node_topology.name, 353 | copies=copies, 354 | ngpus_per_node=ngpus_per_node, 355 | node_links=node_topology.links, 356 | node_invbws=node_topology.invbws, 357 | remote_invbw=node_topology.remote_invbw, 358 | remote_alpha=node_topology.remote_alpha, 359 | remote_beta=node_topology.remote_beta, 360 | internode_conn=internode_conn, 361 | switches=switches 362 | ) 363 | 364 | return topology, route_sketch 365 | -------------------------------------------------------------------------------- /taccl/cli/known_collectives.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import taccl.collectives as collectives 5 | from taccl.serialization import * 6 | from pathlib import Path 7 | import sys 8 | 9 | class KnownCollectives: 10 | def __init__(self, parser): 11 | self.parser = parser 12 | self.constructors = { 13 | 'Broadcast': self._rooted_coll(collectives.broadcast), 14 | 'Reduce': self._rooted_coll(collectives.reduce), 15 | 'Scatter': self._rooted_coll(collectives.scatter), 16 | 'Gather': self._rooted_coll(collectives.gather), 17 | 'Allgather': self._coll(collectives.allgather), 18 | 'Allreduce': self._coll(collectives.allreduce), 19 | 'Alltoall': self._coll(collectives.alltoall), 20 | 'ReduceScatter': self._coll(collectives.reduce_scatter), 21 | 'Scan': self._coll(collectives.scan), 22 | 'MultirootBroadcast': self._multiroot_coll(collectives.multiroot_broadcast), 23 | 'MultirootScatter': self._multiroot_coll(collectives.multiroot_scatter), 24 | 'MultirootGather': self._multiroot_coll(collectives.multiroot_gather), 25 | 'custom': self._custom_coll(), 26 | } 27 | self.parser.add_argument('collective', type=str, choices=self.constructors.keys(), help='collective') 28 | self.parser.add_argument('--collective-file', type=Path, default=None, help='a serialized collective', metavar='FILE') 29 | self.parser.add_argument('--root', type=int, default=0, help='used by rooted collectives', metavar='N') 30 | self.parser.add_argument('--roots', type=int, nargs='+', default=[0], help='used by multi-rooted collectives', metavar='N') 31 | 32 | def create(self, args, num_nodes): 33 | return self.constructors[args.collective](num_nodes, args) 34 | 35 | def _custom_coll(self): 36 | def make(size, args): 37 | input_file = args.collective_file 38 | if input_file is None: 39 | self.parser.error('--collective-file is required for custom collectives') 40 | exit(1) 41 | 42 | if not 
input_file.exists(): 43 | print(f'error: input file not found: {input_file}', file=sys.stderr) 44 | exit(1) 45 | 46 | return load_sccl_object(input_file) 47 | return make 48 | 49 | def _rooted_coll(self, fun): 50 | def make(size, args): 51 | root = args.root 52 | return fun(size, root) 53 | return make 54 | 55 | def _coll(self, fun): 56 | def make(size, args): 57 | return fun(size) 58 | return make 59 | 60 | def _multiroot_coll(self, fun): 61 | def make(size, args): 62 | roots = args.roots 63 | return fun(size, roots) 64 | return make 65 | -------------------------------------------------------------------------------- /taccl/cli/known_topologies.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import taccl.topologies as topologies 5 | from pathlib import Path 6 | import sys 7 | 8 | class KnownTopologies: 9 | def __init__(self, parser, tag=''): 10 | self.parser = parser 11 | self.tag = tag 12 | self.constructors = { 13 | 'HubAndSpoke': self._topo(topologies.hub_and_spoke), 14 | 'DGX2': self._topo(topologies.dgx2), 15 | 'NDv2': self._topo(topologies.ndv2), 16 | 'custom': self._topo(topologies.custom), 17 | } 18 | self.parser.add_argument(f'topology{tag}', type=str, choices=self.constructors.keys(), help=f'topology {tag}') 19 | self.parser.add_argument(f'--topology-file{tag}', type=str, default=None, help=f'profiled topology') 20 | 21 | def _topology(self, args): 22 | return vars(args)[f'topology{self.tag}'] 23 | 24 | def _topology_file(self, args): 25 | input_str = vars(args)[f'topology_file{self.tag}'] 26 | if input_str is None: 27 | self.parser.error(f'--topology-file{self.tag} is required') 28 | exit(1) 29 | 30 | input_file = Path(input_str) 31 | if not input_file.exists(): 32 | print(f'error: input file not found: {input_file}', file=sys.stderr) 33 | exit(1) 34 | 35 | return input_file 36 | 37 | def create(self, args): 38 | topology = 
self.constructors[self._topology(args)](args) 39 | return topology 40 | 41 | def _topo(self, Cls): 42 | def make(args): 43 | return Cls(self._topology_file(args)) 44 | return make 45 | 46 | -------------------------------------------------------------------------------- /taccl/cli/ncclize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from taccl.ncclize import * 5 | from .common import * 6 | 7 | def make_handle_ncclize(cmd_parsers): 8 | cmd = cmd_parsers.add_parser('ncclize') 9 | read_algorithm = add_input_algorithm(cmd, multiple=True) 10 | validate_output_args, output_handler = add_output_file(cmd) 11 | remap_scratch_grp = cmd.add_mutually_exclusive_group() 12 | remap_scratch_grp.add_argument('--remap-scratch', action='store_true', default=None, help='remap scratch buffer indices into free input/output indices') 13 | remap_scratch_grp.add_argument('--no-remap-scratch', action='store_false', dest='remap_scratch', help='don\'t remap scratch buffer indices into free input/output indices') 14 | cmd.add_argument('--no-merge-contiguous', action='store_true', help='don\'t merge sends/receives from/to contiguous memory') 15 | cmd.add_argument('--no-pretty-print', action='store_true', help='don\'t pretty print the generated XML') 16 | cmd.add_argument('--extra-contig', action='store_true', help='allow lucky contiguity') 17 | cmd.add_argument('--channel-policy', type=ChannelPolicy, choices=list(ChannelPolicy), default=ChannelPolicy.MatchTopology, help='channel allocation policy') 18 | cmd.add_argument('--instances', type=int, default=1, help='number of interleaved instances of the algorithm to make') 19 | cmd.add_argument('--scale-remote', type=int, default=1, help='number of interleaved instances of the algorithm to make more for IB') 20 | cmd.add_argument('--prefix', type=str, default="", help='prefix to add to xmlfile') 21 | 22 | 23 | 24 | def 
handle(args, command): 25 | if command != 'ncclize': 26 | return False 27 | 28 | input_algorithms = read_algorithm(args) 29 | validate_output_args(args) 30 | 31 | args.old_format = True 32 | args.use_scratch = True 33 | args.aid_IB_contig = True 34 | 35 | for algo in input_algorithms: 36 | ncclized = ncclize(algo, 37 | remap_scratch=args.remap_scratch, 38 | channel_policy=args.channel_policy, 39 | pretty_print=not args.no_pretty_print, 40 | old_format=args.old_format, 41 | use_scratch=args.use_scratch, 42 | merge_contiguous=not args.no_merge_contiguous, 43 | instances=args.instances, 44 | scale_remote=args.scale_remote, 45 | combine_contig=args.extra_contig, 46 | aid_IB_contig=args.aid_IB_contig, 47 | prefix=args.prefix, 48 | logging=True) 49 | 50 | algo_name = algo.name.replace("[",".") 51 | algo_name = algo_name.replace("]","") 52 | algo_name = algo_name.replace(" ","") 53 | suffix = "" 54 | if args.extra_contig: 55 | suffix += "_extraContig" 56 | if args.aid_IB_contig: 57 | suffix += "_IBContig" 58 | handled = output_handler(args, lambda: ncclized, name_sccl_object(algo_name + f"_i{args.instances}_scRemote{args.scale_remote}{suffix}{args.prefix}", ending='sccl.xml')) 59 | 60 | return True 61 | 62 | return handle 63 | -------------------------------------------------------------------------------- /taccl/cli/solve.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import argparse 5 | import numpy as np 6 | import os 7 | from taccl.routing import TACCLRouting 8 | from taccl.heuristic_ordering import HeuristicOrderer 9 | from taccl.scheduler import TACCLScheduler 10 | from taccl.reduce_scheduler import TACCLRevScheduler 11 | from .known_collectives import KnownCollectives 12 | from .known_topologies import KnownTopologies 13 | from .common import * 14 | 15 | def optimize_comm_sketch(topology, route_sketch, collective, distribute_over_links=False): 16 | path_encoder = TACCLRouting(topology, route_sketch, collective) 17 | orderer = HeuristicOrderer(topology, route_sketch, collective) 18 | scheduler = TACCLScheduler(topology, route_sketch, collective) 19 | 20 | chunk_send, time_send, chunk_recv, time_recv = path_encoder.optimize(distribute_over_links) 21 | time_recv, chunk_recv, switch_time_recv, switch_chunk_recv, switch_time_send, switch_chunk_send, nic_time_recv, nic_chunk_recv, nic_time_send, nic_chunk_send, switch_link_mapping_recv, switch_link_mapping_send, _ = orderer.perform_ordering( 22 | chunk_send, time_send, chunk_recv, time_recv 23 | ) 24 | cont_algo = scheduler.optimize(chunk_recv, time_recv, switch_chunk_recv, switch_time_recv, switch_chunk_send, switch_time_send, nic_chunk_recv, nic_time_recv, nic_chunk_send, nic_time_send, switch_link_mapping_recv, switch_link_mapping_send) 25 | return cont_algo 26 | 27 | 28 | def check_heur_comm_sketch(topology, route_sketch, collective, ts_heur): 29 | path_encoder = TACCLRouting(topology, route_sketch, collective) 30 | orderer = HeuristicOrderer(topology, route_sketch, collective) 31 | scheduler = TACCLScheduler(topology, route_sketch, collective) 32 | 33 | chunk_send, time_send, chunk_recv, time_recv = path_encoder.check_heuristic(ts_heur) 34 | time_recv, chunk_recv, switch_time_recv, switch_chunk_recv, switch_time_send, switch_chunk_send, nic_time_recv, nic_chunk_recv, nic_time_send, nic_chunk_send, switch_link_mapping_recv, switch_link_mapping_send, _ = 
orderer.perform_ordering( 35 | chunk_send, time_send, chunk_recv, time_recv 36 | ) 37 | cont_algo = scheduler.optimize(chunk_recv, time_recv, switch_chunk_recv, switch_time_recv, switch_chunk_send, switch_time_send, nic_chunk_recv, nic_time_recv, nic_chunk_send, nic_time_send, switch_link_mapping_recv, switch_link_mapping_send) 38 | return cont_algo 39 | 40 | def get_send_dict_base(ts=""): 41 | assert len(ts) 42 | return np.load(f"send_dict_{ts}.npy", allow_pickle=True).item() 43 | 44 | def process_dict(send_dict_base, topology, collective): 45 | C = collective.num_chunks 46 | R = collective.num_nodes 47 | L = topology.L 48 | 49 | time_recv = [[[[] for l in range(L)] for src in range(R)] for r in range(R)] 50 | chunk_recv = [[[[] for l in range(L)] for src in range(R)] for r in range(R)] 51 | time_send = [[[[] for l in range(L)] for src in range(R)] for r in range(R)] 52 | chunk_send = [[[[] for l in range(L)] for src in range(R)] for r in range(R)] 53 | 54 | for t in send_dict_base: 55 | for (c,src,r,t_,l) in send_dict_base[t]: 56 | chunk_send[src][r][l].append(c) 57 | time_send[src][r][l].append(t_) 58 | chunk_recv[r][src][l].append(c) 59 | time_recv[r][src][l].append(t_ + topology.get_invbw(src,r)) 60 | return chunk_send, time_send, chunk_recv, time_recv 61 | 62 | def optimize_reduction(reduce_coll, topology, route_sketch, collective, ts, prefer_local_reduce_first=False): 63 | orderer = HeuristicOrderer(topology, route_sketch, collective, reverse=True) 64 | scheduler = TACCLRevScheduler(topology, route_sketch, collective) 65 | 66 | send_dict_base = get_send_dict_base(ts) 67 | chunk_send, time_send, chunk_recv, time_recv = process_dict(send_dict_base, topology, collective) 68 | 69 | # heuristic = 12 in routesketch will reverse the chunk order 70 | time_recv, chunk_order,switch_time_recv, switch_chunk_recv, switch_time_send, switch_chunk_send, nic_time_recv, nic_chunk_recv, nic_time_send, nic_chunk_send, switch_link_mapping_recv, switch_link_mapping_send, paths = 
orderer.perform_ordering(chunk_send, time_send, chunk_recv, time_recv) 71 | for r in range(collective.num_nodes): 72 | for ll in range(len(switch_chunk_recv[r])): 73 | print("new_swt_recv: ", r, ll, switch_chunk_recv[r][ll]) 74 | for r1 in range(len(chunk_order)): 75 | for r2 in range(len(chunk_order[r1])): 76 | for l in range(len(chunk_order[r1][r2])): 77 | if len(chunk_order[r1][r2][l]): 78 | print("old_send_order", r1, r2, chunk_recv[r2][r1][l]) 79 | print("new_send_order", r2, r1, chunk_order[r1][r2][l]) 80 | 81 | ordered_send_dict_reverse = scheduler.optimize_reversed(chunk_order, time_recv, switch_chunk_recv, switch_time_recv, switch_chunk_send, switch_time_send, nic_chunk_recv, nic_time_recv, nic_chunk_send, nic_time_send, switch_link_mapping_recv, switch_link_mapping_send, paths) 82 | np.save(f'send_dict_redscat_{ts}.npy', ordered_send_dict_reverse) 83 | 84 | cont_algo = scheduler.build_allreduce(reduce_coll,ordered_send_dict_reverse, send_dict_base, ts) 85 | 86 | return cont_algo 87 | 88 | def make_handle_solve_comm_sketch(cmd_parsers): 89 | name = 'solve' 90 | cmd = cmd_parsers.add_parser(name) 91 | topologies = KnownTopologies(cmd) 92 | collectives = KnownCollectives(cmd) 93 | validate_output_args, output_handler = add_output_sccl_objects(cmd) 94 | # cmd.add_argument('--topo-file', type=argparse.FileType('r')) 95 | cmd.add_argument('--sketch-file', type=argparse.FileType('r')) 96 | # cmd.add_argument('--topo-name', type=str) 97 | cmd.add_argument('--ts-heur', type=int, default="-1") 98 | def handle(args, command): 99 | if command != name: 100 | return False 101 | 102 | validate_output_args(args) 103 | node_topology = topologies.create(args) 104 | topology, route_sketch = parse_and_get_topo(node_topology, args.sketch_file) 105 | collective = collectives.create(args, topology.num_nodes()).chunk_up(route_sketch.hyperparameters.chunkup) 106 | ts_heur = args.ts_heur 107 | if ts_heur == -1: 108 | algo = optimize_comm_sketch(topology, route_sketch, collective) 
109 | else: 110 | algo = check_heur_comm_sketch(topology, route_sketch, collective, ts_heur) 111 | output_handler(args, algo, algo.name + "_taccl") 112 | return True 113 | 114 | return handle 115 | 116 | def make_handle_combine_comm_sketch(cmd_parsers): 117 | name = 'combine' 118 | cmd = cmd_parsers.add_parser(name) 119 | topologies = KnownTopologies(cmd) 120 | collectives = KnownCollectives(cmd) 121 | validate_output_args, output_handler = add_output_sccl_objects(cmd) 122 | cmd.add_argument('--sketch-file', type=str, default=None) 123 | cmd.add_argument('--ts', type=str, help='timestamp of send_dict for Allgather') 124 | cmd.add_argument('--prefer-local-reduce-first', action='store_true', help='should prefer reducing a chunk locally first if it is the same either way') 125 | def handle(args, command): 126 | if command != name: 127 | return False 128 | if args.sketch_file is None: 129 | cmd.error('Must specify sketch file') 130 | 131 | assert os.path.isfile(args.sketch_file), "sketch file does not exist" 132 | sketch_file = open(args.sketch_file, 'r') 133 | 134 | validate_output_args(args) 135 | node_topology = topologies.create(args) 136 | topology, route_sketch = parse_and_get_topo(node_topology, sketch_file, reduce=True) 137 | collective = collectives.create(args, topology.num_nodes()).chunk_up(route_sketch.hyperparameters.chunkup) 138 | 139 | import copy 140 | new_args = copy.deepcopy(args) 141 | # new_args.collective = 'Allgather' 142 | # new_args.collective = 'ReduceScatter' 143 | new_args.collective = 'Allreduce' 144 | allreduce_coll = collectives.create(new_args, topology.num_nodes()).chunk_up(route_sketch.hyperparameters.chunkup) 145 | algo = optimize_reduction(allreduce_coll, topology, route_sketch, collective, args.ts, args.prefer_local_reduce_first) 146 | output_handler(args, algo, algo.name + "_taccl") 147 | return True 148 | 149 | return handle -------------------------------------------------------------------------------- 
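The `send_dict` consumed by `get_send_dict_base`/`process_dict` in solve.py maps each time step to a list of `(chunk, src, dst, send_time, link)` tuples, which `process_dict` reshapes into per-link send/receive schedules. A minimal standalone sketch of that reshaping, with the helper lightly adapted from `process_dict` above and stub topology/collective classes (hypothetical names) standing in for TACCL's real ones:

```python
# Illustrative sketch only: StubTopology/StubCollective are stand-ins for
# TACCL's real classes; the helper mirrors process_dict from solve.py above.
def process_dict(send_dict_base, topology, collective):
    R = collective.num_nodes
    L = topology.L
    time_recv = [[[[] for _ in range(L)] for _ in range(R)] for _ in range(R)]
    chunk_recv = [[[[] for _ in range(L)] for _ in range(R)] for _ in range(R)]
    time_send = [[[[] for _ in range(L)] for _ in range(R)] for _ in range(R)]
    chunk_send = [[[[] for _ in range(L)] for _ in range(R)] for _ in range(R)]
    for t in send_dict_base:
        for (c, src, r, t_, l) in send_dict_base[t]:
            chunk_send[src][r][l].append(c)
            time_send[src][r][l].append(t_)
            chunk_recv[r][src][l].append(c)
            # arrival time = send time + inverse bandwidth of the link
            time_recv[r][src][l].append(t_ + topology.get_invbw(src, r))
    return chunk_send, time_send, chunk_recv, time_recv

class StubTopology:
    L = 1  # a single link type between each GPU pair
    def get_invbw(self, src, dst):
        return 2  # constant transfer cost for the sketch

class StubCollective:
    num_nodes = 2
    num_chunks = 2

# time step -> list of (chunk, src, dst, send_time, link) entries
sends = {0: [(0, 0, 1, 0, 0)], 1: [(1, 1, 0, 1, 0)]}
chunk_send, time_send, chunk_recv, time_recv = process_dict(
    sends, StubTopology(), StubCollective())
assert chunk_send[0][1][0] == [0]  # GPU 0 sends chunk 0 to GPU 1
assert time_recv[1][0][0] == [2]   # arrival = send time + link cost
```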
/taccl/collectives.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from abc import ABC, abstractmethod 5 | from dataclasses import dataclass 6 | 7 | @dataclass 8 | class Chunk: 9 | precondition: set 10 | postcondition: set 11 | address: int 12 | 13 | @dataclass 14 | class Rank: 15 | precondition: set 16 | postcondition: set 17 | id: int 18 | 19 | class Collective: 20 | def __init__(self, name, num_nodes, chunks, ranks=None, triggers = {}): 21 | self.name = name 22 | self.num_nodes = num_nodes 23 | assert ranks is not None 24 | assert len(ranks) == num_nodes 25 | self.num_chunks = len(chunks) 26 | self._chunks = chunks 27 | self._ranks = ranks 28 | self._triggers = triggers 29 | 30 | self.is_combining = False 31 | addresses_seen = set() 32 | for chunk in self._chunks: 33 | if chunk.address in addresses_seen: 34 | self.is_combining = True 35 | addresses_seen.add(chunk.address) 36 | self.num_addresses = len(addresses_seen) 37 | 38 | def ranks(self): 39 | return range(self.num_nodes) 40 | 41 | def chunks(self): 42 | return range(len(self._chunks)) 43 | 44 | def precondition(self, rank, chunk): 45 | return rank in self._chunks[chunk].precondition 46 | 47 | def postcondition(self, rank, chunk): 48 | return rank in self._chunks[chunk].postcondition 49 | 50 | def address(self, chunk): 51 | return self._chunks[chunk].address 52 | 53 | def trigger(self, rank, chunk): 54 | if (rank, chunk) in self._triggers: 55 | return self._triggers[(rank, chunk)] 56 | else: 57 | return None 58 | 59 | def pre_rank(self, chunk): 60 | return self._chunks[chunk].precondition 61 | 62 | def post_rank(self, chunk): 63 | return self._chunks[chunk].postcondition 64 | 65 | def pre_chunk(self, rank): 66 | return self._ranks[rank].precondition 67 | 68 | def post_chunk(self, rank): 69 | return self._ranks[rank].postcondition 70 | 71 | def has_triggers(self): 72 | return 
len(self._triggers) > 0 73 | 74 | def chunk_up(self, div): 75 | if div < 1: 76 | raise ValueError('Divisor must be greater than or equal to one (and one is a no-op).') 77 | if div == 1: 78 | return self 79 | 80 | def remap(addr, i): 81 | return addr * div + i 82 | 83 | new_chunks = [] 84 | new_ranks = [] 85 | for chunk in self._chunks: 86 | for i in range(div): 87 | new_chunks.append(Chunk(chunk.precondition, chunk.postcondition, remap(chunk.address, i))) 88 | for rank in self._ranks: 89 | new_rank_precondition = set(remap(chunk, i) for i in range(div) for chunk in rank.precondition) 90 | new_rank_postcondition = set(remap(chunk, i) for i in range(div) for chunk in rank.postcondition) 91 | new_ranks.append(Rank(new_rank_precondition, new_rank_postcondition, rank.id)) 92 | 93 | name = f'{self.name},chunks={div}' 94 | return Collective(name, self.num_nodes, new_chunks, new_ranks) 95 | 96 | def __str__(self): 97 | collstr = "{}\n \t num_nodes: {}\n \t num_chunks: {}".format(self.name, self.num_nodes, self.num_chunks) 98 | return collstr 99 | 100 | def build_collective(name, num_nodes, num_chunks, precondition, postcondition, address = lambda c: c, trigger = lambda r, c: None): 101 | chunks = [] 102 | ranks = [] 103 | for chunk in range(num_chunks): 104 | chunk_precondition = set(rank for rank in range(num_nodes) if precondition(rank, chunk)) 105 | chunk_postcondition = set(rank for rank in range(num_nodes) if postcondition(rank, chunk)) 106 | chunk_address = address(chunk) 107 | chunks.append(Chunk(chunk_precondition, chunk_postcondition, chunk_address)) 108 | for rank in range(num_nodes): 109 | rank_precondition = set(chunk for chunk in range(num_chunks) if precondition(rank, chunk)) 110 | rank_postcondition = set(chunk for chunk in range(num_chunks) if postcondition(rank, chunk)) 111 | ranks.append(Rank(rank_precondition, rank_postcondition, rank)) 112 | triggers = {(rank, chunk): trigger(rank, chunk) for rank in range(num_nodes) for chunk in range(num_chunks) if 
trigger(rank, chunk) != None} 113 | return Collective(name, num_nodes, chunks, ranks, triggers) 114 | 115 | # Common pre- and postconditions 116 | def _scattered(num_nodes, chunks = 1): 117 | def cond(rank, chunk): 118 | return rank == (chunk // chunks) % num_nodes 119 | return cond 120 | 121 | def _transpose(num_nodes): 122 | def cond(rank, chunk): 123 | return rank == chunk // num_nodes 124 | return cond 125 | 126 | def _all(rank, chunk): 127 | return True 128 | 129 | def _root(root): 130 | def cond(rank, chunk): 131 | return rank == root 132 | return cond 133 | 134 | # Non-combining collectives 135 | 136 | def broadcast(num_nodes, root): 137 | return build_collective(f'Broadcast(n={num_nodes},root={root})', num_nodes, 1, _root(root), _all) 138 | 139 | def scatter(num_nodes, root): 140 | return build_collective(f'Scatter(n={num_nodes},root={root})', num_nodes, num_nodes, _root(root), _scattered(num_nodes)) 141 | 142 | def gather(num_nodes, root): 143 | return build_collective(f'Gather(n={num_nodes},root={root})', num_nodes, num_nodes, _scattered(num_nodes), _root(root)) 144 | 145 | def allgather(num_nodes): 146 | return build_collective(f'Allgather(n={num_nodes})', num_nodes, num_nodes, _scattered(num_nodes), _all) 147 | 148 | def alltoall(num_nodes): 149 | return build_collective(f'Alltoall(n={num_nodes})', num_nodes, num_nodes * num_nodes, _scattered(num_nodes), _transpose(num_nodes)) 150 | 151 | # Combining collectives 152 | 153 | # Represents a single buffer to reduce 154 | def _single_scattered(num_nodes): 155 | def address(chunk): 156 | return chunk // num_nodes 157 | return address 158 | 159 | def reduce(num_nodes, root): 160 | return build_collective(f'Reduce(n={num_nodes},root={root})', num_nodes, num_nodes, _scattered(num_nodes), _root(root), _single_scattered(num_nodes)) 161 | 162 | def allreduce(num_nodes): 163 | return build_collective(f'Allreduce(n={num_nodes})', num_nodes, num_nodes, _scattered(num_nodes), _all, _single_scattered(num_nodes)) 164 | 
165 | def reduce_scatter(num_nodes): 166 | return build_collective(f'ReduceScatter(n={num_nodes})', num_nodes, num_nodes * num_nodes, _scattered(num_nodes), _transpose(num_nodes), _single_scattered(num_nodes)) 167 | 168 | def scan(num_nodes): 169 | def postcondition(rank, chunk): 170 | origin = chunk % num_nodes 171 | return rank >= origin 172 | return build_collective(f'Scan(n={num_nodes})', num_nodes, num_nodes, _scattered(num_nodes), postcondition, _single_scattered(num_nodes)) 173 | 174 | # Multi-root generalizations of MPI rooted collectives 175 | # TODO: Add one for reduce. That needs a new addressing function. 176 | 177 | def _roots(roots): 178 | def cond(rank, chunk): 179 | return rank == roots[chunk % len(roots)] 180 | return cond 181 | 182 | def multiroot_broadcast(num_nodes, roots): 183 | return build_collective(f'MultirootBroadcast(n={num_nodes},roots=({",".join(str(i) for i in roots)}))', num_nodes, len(roots), _roots(roots), _all) 184 | 185 | def multiroot_scatter(num_nodes, roots): 186 | return build_collective(f'MultirootScatter(n={num_nodes},roots=({",".join(str(i) for i in roots)}))', num_nodes, num_nodes * len(roots), _roots(roots), _scattered(num_nodes, len(roots))) 187 | 188 | def multiroot_gather(num_nodes, roots): 189 | return build_collective(f'MultirootGather(n={num_nodes},roots=({",".join(str(i) for i in roots)}))', num_nodes, num_nodes * len(roots), _scattered(num_nodes, len(roots)), _roots(roots)) 190 | -------------------------------------------------------------------------------- /taccl/examples/sketch/sk-dgx2-n1.json: -------------------------------------------------------------------------------- 1 | { 2 | "nnodes": 1, 3 | "intranode_sketch": { 4 | "strategy": "switch", 5 | "switches": [[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]], 6 | "switch_hyperedge_strategy": ["uc-min"] 7 | }, 8 | "symmetry_offsets": [[2, 16]], 9 | "hyperparameters": { 10 | "input_chunkup": 2 11 | } 12 | } 
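The collective definitions in `collectives.py` above pin chunks to ranks purely through small precondition/postcondition predicates. As a quick illustration, here is a standalone copy of the `_scattered` and `_transpose` helpers (re-implemented verbatim so the example runs without TACCL installed), showing which ranks hold which chunks for a 3-node Allgather and Alltoall:

```python
# Standalone copies of the condition helpers from collectives.py,
# kept self-contained so this example runs without TACCL installed.

def _scattered(num_nodes, chunks=1):
    # Precondition of allgather/alltoall: chunk c starts on rank (c // chunks) % num_nodes.
    def cond(rank, chunk):
        return rank == (chunk // chunks) % num_nodes
    return cond

def _transpose(num_nodes):
    # Postcondition of alltoall: chunk c must end up on rank c // num_nodes.
    def cond(rank, chunk):
        return rank == chunk // num_nodes
    return cond

n = 3
pre = _scattered(n)
# Allgather(n=3) has 3 chunks; each starts on exactly one rank.
starts = {c: [r for r in range(n) if pre(r, c)] for c in range(n)}

post = _transpose(n)
# Alltoall(n=3) has n*n = 9 chunks; chunks 0..2 all end on rank 0, 3..5 on rank 1, etc.
dests = {c: [r for r in range(n) if post(r, c)] for c in range(n * n)}

print(starts)    # {0: [0], 1: [1], 2: [2]}
print(dests[3])  # [1]
```

This is why `alltoall` is built with `num_nodes * num_nodes` chunks while `allgather` needs only `num_nodes`: the transpose postcondition assigns each chunk a unique (source, destination) pair.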
-------------------------------------------------------------------------------- /taccl/examples/sketch/sk-ndv2-n1-cUp6.json: -------------------------------------------------------------------------------- 1 | { 2 | "nnodes": 1, 3 | "intranode_sketch": { 4 | "strategy": "none" 5 | }, 6 | "symmetry_offsets": [], 7 | "hyperparameters": { 8 | "input_chunkup": 6 9 | } 10 | } -------------------------------------------------------------------------------- /taccl/examples/sketch/sk1-dgx2-n1.json: -------------------------------------------------------------------------------- 1 | { 2 | "nnodes": 1, 3 | "intranode_sketch": { 4 | "strategy": "switch", 5 | "switches": [[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]], 6 | "switch_hyperedge_strategy": ["uc-min"] 7 | }, 8 | "symmetry_offsets": [[2, 16]], 9 | "hyperparameters": { 10 | "input_chunkup": 1 11 | } 12 | } -------------------------------------------------------------------------------- /taccl/examples/sketch/sk1-dgx2-n2.json: -------------------------------------------------------------------------------- 1 | { 2 | "nnodes": 2, 3 | "intranode_sketch": { 4 | "strategy": "switch", 5 | "switches": [[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]], 6 | "switch_hyperedge_strategy": ["uc-min"] 7 | }, 8 | "internode_sketch": { 9 | "strategy": "relay", 10 | "internode_conn": {"1" : [0], "3" : [2], "5" : [4], "7" : [6], "9" : [8], "11" : [10], "13" : [12], "15" : [14]}, 11 | "gpus_to_sender_rev_map": { 12 | "1" : [0, 1], 13 | "3" : [2, 3], 14 | "5" : [4, 5], 15 | "7" : [6, 7], 16 | "9" : [8, 9], 17 | "11" : [10, 11], 18 | "13" : [12, 13], 19 | "15" : [14, 15] 20 | }, 21 | "enforce_ordering": true 22 | }, 23 | "symmetry_offsets": [[2, 16], [16, 32]], 24 | "hyperparameters": { 25 | "input_chunkup": 2 26 | } 27 | } -------------------------------------------------------------------------------- /taccl/examples/sketch/sk1-ndv2-n1.json: -------------------------------------------------------------------------------- 1 | { 2 | "nnodes": 1, 3 | 
"intranode_sketch": { 4 | "strategy": "none" 5 | }, 6 | "symmetry_offsets": [], 7 | "hyperparameters": { 8 | "input_chunkup": 1 9 | } 10 | } -------------------------------------------------------------------------------- /taccl/examples/sketch/sk1-ndv2-n2.json: -------------------------------------------------------------------------------- 1 | { 2 | "nnodes": 2, 3 | "intranode_sketch": { 4 | "strategy": "none" 5 | }, 6 | "internode_sketch": { 7 | "strategy": "relay", 8 | "internode_conn": {"0" : [1]} 9 | }, 10 | "symmetry_offsets": [[8,16]], 11 | "hyperparameters": { 12 | "input_chunkup": 1 13 | } 14 | } -------------------------------------------------------------------------------- /taccl/examples/sketch/sk1-ndv2-n4.json: -------------------------------------------------------------------------------- 1 | { 2 | "nnodes": 4, 3 | "intranode_sketch": { 4 | "strategy": "none" 5 | }, 6 | "internode_sketch": { 7 | "strategy": "relay", 8 | "internode_conn": {"0" : [1]} 9 | }, 10 | "symmetry_offsets": [[8,32]], 11 | "hyperparameters": { 12 | "input_chunkup": 1 13 | } 14 | } -------------------------------------------------------------------------------- /taccl/examples/sketch/sk2-dgx2-n2.json: -------------------------------------------------------------------------------- 1 | { 2 | "nnodes": 2, 3 | "intranode_sketch": { 4 | "strategy": "switch", 5 | "switches": [[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]], 6 | "switch_hyperedge_strategy": ["uc-max"] 7 | }, 8 | "internode_sketch": { 9 | "strategy": "relay", 10 | "internode_conn": "direct-map", 11 | "gpus_to_sender_rev_map": { 12 | "0" : [0], 13 | "1" : [1], 14 | "2" : [2], 15 | "3" : [3], 16 | "4" : [4], 17 | "5" : [5], 18 | "6" : [6], 19 | "7" : [7], 20 | "8" : [8], 21 | "9" : [9], 22 | "10" : [10], 23 | "11" : [11], 24 | "12" : [12], 25 | "13" : [13], 26 | "14" : [14], 27 | "15" : [15] 28 | } 29 | }, 30 | "symmetry_offsets": [[2, 16], [16, 32]], 31 | "hyperparameters": { 32 | "input_chunkup": 1 33 | } 34 | } 
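The sketch files in this directory share a small schema: `nnodes`, an `intranode_sketch`, an optional `internode_sketch` for multi-node runs, `symmetry_offsets`, and `hyperparameters.input_chunkup`. Below is a minimal sanity check with `sk1-ndv2-n2.json` inlined; this checker is illustrative and not part of TACCL, and the reading of each `symmetry_offsets` pair as `[offset, size]` with `0 < offset <= size` is an assumption based on the examples above:

```python
import json

# sk1-ndv2-n2.json inlined so the example is self-contained.
raw = """
{
    "nnodes": 2,
    "intranode_sketch": { "strategy": "none" },
    "internode_sketch": { "strategy": "relay", "internode_conn": {"0": [1]} },
    "symmetry_offsets": [[8, 16]],
    "hyperparameters": { "input_chunkup": 1 }
}
"""
sketch = json.loads(raw)

assert sketch["nnodes"] >= 1
# The examples above use either "none" or "switch" intra-node strategies.
assert sketch["intranode_sketch"]["strategy"] in ("none", "switch")
# Multi-node sketches additionally carry an inter-node strategy ("relay" in the examples).
if sketch["nnodes"] > 1:
    assert sketch["internode_sketch"]["strategy"] == "relay"
# Assumption: each symmetry entry is a pair [offset, size] with 0 < offset <= size.
for offset, size in sketch["symmetry_offsets"]:
    assert 0 < offset <= size
assert sketch["hyperparameters"]["input_chunkup"] >= 1
print("sketch ok")
```

For the NDv2 examples, `[[8, 16]]` matches two 8-GPU nodes (16 GPUs total), while the DGX-2 examples use `[[2, 16]]` within a 16-GPU node plus `[16, 32]` across two nodes.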
-------------------------------------------------------------------------------- /taccl/examples/sketch/sk2-ndv2-n2.json: -------------------------------------------------------------------------------- 1 | { 2 | "nnodes": 2, 3 | "intranode_sketch": { 4 | "strategy": "none" 5 | }, 6 | "internode_sketch": { 7 | "strategy": "relay", 8 | "internode_conn": "fully-connected" 9 | }, 10 | "symmetry_offsets": [[8,16]], 11 | "hyperparameters": { 12 | "input_chunkup": 1 13 | } 14 | } -------------------------------------------------------------------------------- /taccl/examples/sketch/sk2-ndv2-n4.json: -------------------------------------------------------------------------------- 1 | { 2 | "nnodes": 4, 3 | "intranode_sketch": { 4 | "strategy": "none" 5 | }, 6 | "internode_sketch": { 7 | "strategy": "relay", 8 | "internode_conn": "fully-connected" 9 | }, 10 | "symmetry_offsets": [[8,32]], 11 | "hyperparameters": { 12 | "input_chunkup": 1 13 | } 14 | } -------------------------------------------------------------------------------- /taccl/examples/sketch/sk3-ndv2-n2.json: -------------------------------------------------------------------------------- 1 | { 2 | "nnodes": 2, 3 | "intranode_sketch": { 4 | "strategy": "none" 5 | }, 6 | "internode_sketch": { 7 | "strategy": "relay", 8 | "internode_conn": "direct-map" 9 | }, 10 | "symmetry_offsets": [[8,16]], 11 | "hyperparameters": { 12 | "input_chunkup": 1 13 | } 14 | } -------------------------------------------------------------------------------- /taccl/examples/topo/topo-dgx2-1KB.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "DGX2_1KB", 3 | "gpus_per_node": 16, 4 | "nics_per_node": 8, 5 | "alpha": 30, 6 | "node_betas_list": [1], 7 | "node_invbws_list": [31], 8 | "remote_invbw": 270, 9 | "remote_alpha": 260, 10 | "remote_beta": 10.25 11 | } -------------------------------------------------------------------------------- /taccl/examples/topo/topo-dgx2-1MB.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "DGX2_1MB", 3 | "gpus_per_node": 16, 4 | "nics_per_node": 8, 5 | "alpha": 0.3, 6 | "node_betas_list": [8], 7 | "node_invbws_list": [8], 8 | "remote_invbw": 107, 9 | "remote_alpha": 2.6, 10 | "remote_beta": 105 11 | } -------------------------------------------------------------------------------- /taccl/examples/topo/topo-ndv2-1KB.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "NDv2_1KB", 3 | "gpus_per_node": 8, 4 | "nics_per_node": 1, 5 | "alpha": 30, 6 | "node_betas_list": [2.24, 4.50], 7 | "node_invbws_list": [32.25, 34.50], 8 | "remote_alpha": 260, 9 | "remote_beta": 10.25, 10 | "remote_invbw": 270.25 11 | } -------------------------------------------------------------------------------- /taccl/examples/topo/topo-ndv2-1MB.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "NDv2_1MB", 3 | "gpus_per_node": 8, 4 | "nics_per_node": 1, 5 | "alpha": 0.3, 6 | "node_betas_list": [23, 46], 7 | "node_invbws_list": [23, 46], 8 | "remote_invbw": 107, 9 | "remote_alpha": 2.6, 10 | "remote_beta": 105 11 | } -------------------------------------------------------------------------------- /taccl/examples/topo/topo-ndv2-32KB.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "NDv2_32KB", 3 | "gpus_per_node": 8, 4 | "nics_per_node": 1, 5 | "alpha": 3, 6 | "node_betas_list": [7.18, 14.38], 7 | "node_invbws_list": [10.18, 17.38], 8 | "remote_alpha": 26, 9 | "remote_beta": 32.8, 10 | "remote_invbw": 59 11 | } -------------------------------------------------------------------------------- /taccl/instance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from dataclasses import dataclass 5 | 6 | @dataclass(frozen=True) 7 | class Instance: 8 | steps: int 9 | extra_rounds: int = 0 10 | chunks: int = 1 11 | pipeline: int = None 12 | extra_memory: int = None 13 | allow_exchange: bool = False 14 | 15 | def rounds(self): 16 | return self.steps + self.extra_rounds 17 | 18 | def set(self, steps = None, extra_rounds = None, chunks = None, pipeline = None, extra_memory = None, allow_exchange = None): 19 | return Instance( 20 | steps if steps is not None else self.steps, 21 | extra_rounds if extra_rounds is not None else self.extra_rounds, 22 | chunks if chunks is not None else self.chunks, 23 | pipeline if pipeline is not None else self.pipeline, 24 | extra_memory if extra_memory is not None else self.extra_memory, 25 | allow_exchange if allow_exchange is not None else self.allow_exchange) 26 | 27 | def __str__(self): 28 | s = f'steps={self.steps}' 29 | if self.extra_rounds > 0: 30 | s += f',rounds={self.steps + self.extra_rounds}' 31 | if self.chunks > 1: 32 | s += f',chunks={self.chunks}' 33 | if self.pipeline is not None: 34 | s += f',pipeline={self.pipeline}' 35 | if self.extra_memory is not None: 36 | s += f',extra_memory={self.extra_memory}' 37 | if self.allow_exchange: 38 | s += ',allow_exchange' 39 | return s 40 | -------------------------------------------------------------------------------- /taccl/reduce_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from taccl import topologies 5 | from taccl.algorithm import * 6 | from taccl.heuristic_ordering import HeuristicOrderer 7 | from taccl.instance import * 8 | from taccl.shortest_path_sets import * 9 | from gurobipy import GRB, Model, quicksum, abs_, and_ 10 | from taccl.utils import * 11 | import numpy as np 12 | 13 | class TACCLRevScheduler(object): 14 | def __init__(self, topology, route_sketch, collective): 15 | self.topology = topology 16 | self.route_sketch = route_sketch 17 | self.collective = collective 18 | 19 | def latency(self, src, dst, l): 20 | return self.topology.get_invbw(src,dst) 21 | 22 | def _is_relay_link(self,r,dst): 23 | if self.topology.gpu_to_node(r) != self.topology.gpu_to_node(dst): 24 | return True 25 | return False 26 | 27 | def _encode(self, opt, chunk_order, chunk_time, 28 | switch_chunk_order_recv, switch_chunk_time_recv, switch_chunk_order_send, switch_chunk_time_send, 29 | nic_chunk_order_recv, nic_chunk_time_recv, nic_chunk_order_send, nic_chunk_time_send, 30 | switch_link_mapping_recv=None, switch_link_mapping_send=None, 31 | endpoints_rc=[], prefer_local_reduce_first=True, extra_heuristic=True): 32 | 33 | # self.spsets = shortest_path_sets(self.topology, self.collective) 34 | heuristic = self.route_sketch.hyperparameters.heuristic 35 | self.chunkup = self.route_sketch.hyperparameters.chunkup 36 | C = self.collective.num_chunks 37 | R = self.collective.num_nodes 38 | L = self.topology.L 39 | smallM = 10 40 | M = 10000000 # big-M for maximum self.time between sends 41 | ST = 5000000 # self.time for unsent sends and unstarted starts 42 | SND = 10000000 # self.time for unsent sends and unstarted starts 43 | opt.Params.Threads = 4 44 | opt.Params.IntegralityFocus = 1 45 | 46 | self.is_sent_set_1 = set() 47 | self.is_before_set_1 = set() # Fixed ordering of chunks over a link 48 | self.is_together_set_0 = set() # Fix if chunks are received together on a GPU 49 | self.is_together_set_1 = set() 50 | self.recv_first_set_1 = set() # 
Fixed ordering between recvs on a switch 51 | self.nic_recv_first_set_1 = set() # Fixed ordering between recvs on a NIC 52 | self.send_first_set_1 = set() # Fixed ordering between sends on a switch 53 | self.nic_send_first_set_1 = set() # Fixed ordering between sends on a NIC 54 | 55 | self.is_before = {} # (c,o,r): c is received on GPU r before o, but from same source 56 | self.is_together = {} # (c,o,r): c dst only if r -> dst is an IB 77 | def _should_try_together(r,dst,c,o): 78 | if (self.topology.copies <= 1): 79 | return False 80 | assert r != dst 81 | if self._is_relay_link(r,dst): 82 | # for allgather, not contig unless they are the input of the same GPU 83 | if "Allgather" in self.collective.name and (c//self.chunkup != o//self.chunkup): 84 | return False 85 | return True 86 | return False 87 | 88 | # Can fix contiguous sends if reqd 89 | def _should_fix_together(r,dst,c,o): 90 | return False 91 | if not (isinstance(self.topology, DistributedTopology) and self.topology.m_top == MachineTopology.RELAYED): 92 | return False 93 | for rc in self.collective.pre_on(c): 94 | r1 = rc 95 | for ro in self.collective.pre_on(o): 96 | r2 = ro 97 | assert r != dst 98 | if self._is_relay_link(r,dst): 99 | if self.topology.bw_dist[rc][r] == self.topology.bw_dist[ro][r]: 100 | return True 101 | return False 102 | 103 | # Populate is_sent_set_1 from the chunk_order received from path encoding 104 | def _add_chunk_sent(opt, heuristic): 105 | if chunk_order is not None: 106 | assert chunk_time is not None 107 | assert len(chunk_order) == R 108 | assert len(chunk_order[0]) == R 109 | assert len(chunk_order[0][0]) <= L 110 | # TODO: we will do distribution at heuristic_ordering.py 111 | # dist_link_heuristic = [3,5,8,9,10,11,13,12] # Distribute chunk sends if there are multiple links connecting src to r 112 | for r in range(R): 113 | for src in self.topology.sources(r): 114 | for l in range(self.topology.link(src,r)): 115 | for c in chunk_order[r][src][l]: 116 | 
self.is_sent_set_1.add((c,src,r,l)) 117 | 118 | def _add_switch_order(switch_chunk_order_recv, switch_chunk_order_send, switch_link_mapping_recv, switch_link_mapping_send): 119 | # Order recvs coming into and going out from a GPU connected to a switch 120 | # recv_right_after[r][(c,srci)] = (o,srcj) => GPU r receives o from srcj right after receiving c from srci 121 | # send_right_after[r][(c,dsti)] = (o,dstj) => GPU r sends o to dstj right after sending c to dsti 122 | # (c,o,r,srci,srcj) \in recv_first_set_1 => c is recvd on r from srci anytime before o is recvd on r from srcj 123 | # (c,o,ri,dsti,dstj) \in send_first_set_1 => c is sent from r to dsti anytime before o is sent from r to dstj 124 | LL = 0 125 | for r in range(R): 126 | LL = max(LL, len(switch_chunk_order_recv[r])) 127 | LL = max(LL, len(switch_chunk_order_send[r])) 128 | 129 | recv_right_after, recv_first_set_1, send_right_after, send_first_set_1 = add_switch_order(switch_chunk_order_recv, 130 | switch_chunk_order_send, 131 | switch_link_mapping_recv, 132 | switch_link_mapping_send, R, LL) 133 | for recv in recv_first_set_1: 134 | self.recv_first_set_1.add(recv) 135 | for send in send_first_set_1: 136 | self.send_first_set_1.add(send) 137 | return recv_right_after, send_right_after 138 | 139 | def _add_chunk_order(opt, heuristic, recv_right_after, send_right_after): 140 | # dist_link_heuristic = [3,5,8,9,10,12, 13] # Distribute chunk sends if there are multiple links connecting src to r 141 | if chunk_order is not None: 142 | assert chunk_time is not None 143 | assert len(chunk_order) == R 144 | assert len(chunk_order[0]) == R 145 | for r in range(R): 146 | for src in self.topology.sources(r): 147 | if len(chunk_order[r][src][0]): 148 | print("chunk_order", r, src, chunk_order[r][src][0]) 149 | for l in range(self.topology.link(src,r)): 150 | this_chunk_order = chunk_order[r][src][l] 151 | max_contig = 6 152 | for i, c in enumerate(this_chunk_order): 153 | j = i + 1 154 | while j 0 164 | if 
(recv_right_after[r][ll][(c,src)] != recv_right_after[r][ll][(o,src)] or send_right_after[src][ll_src][(c,r)] != send_right_after[src][ll_src][(o,r)]): 165 | self.is_together_set_0.add((c1,o1,r,src)) 166 | self.is_before_set_1.add((c,o,r,src)) 167 | skip_others = True 168 | if not skip_others: 169 | if _should_fix_together(src,r,c,o): 170 | self.is_together_set_1.add((c1,o1,r,src)) 171 | # Max contiguity allowed = 6 172 | elif _should_try_together(src,r,c,o) and j-i{r})') 174 | is_before_ocr = 0 175 | if not extra_heuristic: 176 | if (o,c,r,src) not in self.is_before: 177 | self.is_before[(o,c,r,src)] = opt.addVar(vtype=GRB.BINARY) 178 | is_before_ocr = self.is_before[(o,c,r,src)] 179 | else: 180 | assert (o,c,r,src) not in self.is_before 181 | if (c,o,r,src) not in self.is_before: 182 | self.is_before[(c,o,r,src)] = opt.addVar(vtype=GRB.BINARY) 183 | if (c1,o1,r) not in self.is_together: 184 | self.is_together[(c1,o1,r,src)] = opt.addVar(vtype=GRB.BINARY) 185 | opt.addLConstr(self.is_before[(c,o,r,src)] + self.is_together[(c1,o1,r,src)] + is_before_ocr == 1) 186 | # send chunk together with another only if the previous chunk between the two has been sent together 187 | if j-1>i: 188 | opt.addLConstr(self.is_together[(c1,o1,r,src)] <= self.is_together[(c2,prev_o2,r,src)]) 189 | else: 190 | self.is_together_set_0.add((c1,o1,r,src)) 191 | self.is_before_set_1.add((c,o,r,src)) 192 | # if c1 == 34 and o1 == 35: 193 | # print("2. 
added is_before ", c,o,r,src) 194 | j = j + 1 195 | i = i + 1 196 | 197 | def alpha(r,dst): 198 | assert r != dst 199 | if self._is_relay_link(r,dst): 200 | alpha = self.topology.remote_alpha 201 | assert alpha is not None 202 | return alpha 203 | return 0 204 | 205 | def beta(r,dst): 206 | assert r != dst 207 | if self._is_relay_link(r,dst): 208 | beta = self.topology.remote_beta 209 | assert beta is not None 210 | return beta 211 | return self.topology.get_invbw(r,dst) 212 | 213 | def calc_latency(src,r,l,c): 214 | if self._is_relay_link(src,r): 215 | num_s = 0 216 | for o in range(C): 217 | o1,c1 = minmax(o,c) 218 | if (o1,c1,r,src) in self.is_together_set_1: 219 | assert (o1,c1,r,src) not in self.is_together 220 | num_s = num_s + 1 221 | continue 222 | if (o1,c1,r,src) in self.is_together_set_0: 223 | assert (o1,c1,r,src) not in self.is_together 224 | else: 225 | if (o1,c1,r,src) not in self.is_together: 226 | self.is_together[(o1,c1,r,src)] = opt.addVar(vtype=GRB.BINARY) 227 | lat = alpha(src,r) + beta(src,r)*(num_s + quicksum(self.is_together[(o,c,r,src)] if (o,c,r,src) in self.is_together else 0 for o in range(c)) + quicksum(self.is_together[(c,o,r,src)] if (c,o,r,src) in self.is_together else 0 for o in range(c,C))) 228 | return lat 229 | return alpha(src,r) + beta(src,r) 230 | 231 | # Set chunk is_send_set 232 | _add_chunk_sent(opt, heuristic) 233 | 234 | # Populate values 235 | for c in self.collective.chunks(): 236 | for r in self.collective.ranks(): 237 | recvd_anytime = sum([sum([1 if (c,src,r,l) in self.is_sent_set_1 else 0 for l in range(L)]) for src in self.topology.sources(r)]) 238 | recv_IB = sum([sum([1 if (c,src,r,l) in self.is_sent_set_1 and self._is_relay_link(src,r) else 0 for l in range(L)]) for src in self.topology.sources(r)]) 239 | if recvd_anytime == 0: 240 | for srci in self.topology.sources(r): 241 | assert (c,c,r,srci) not in self.is_together_set_1 242 | assert (c,c,r,srci) not in self.is_together 243 | 
self.is_together_set_0.add((c,c,r,srci)) 244 | else: 245 | # Will receive a chunk at most once 246 | # assert recvd_anytime == 1 247 | for srci in self.topology.sources(r): 248 | assert (c,c,r,srci) not in self.is_together_set_1 249 | assert (c,c,r,srci) not in self.is_together 250 | self.is_together_set_1.add((c,c,r,srci)) 251 | 252 | # Set ordering 253 | should_add_switch_order = True 254 | recv_right_after = {} 255 | send_right_after = {} 256 | if should_add_switch_order: 257 | recv_right_after, send_right_after = _add_switch_order( 258 | switch_chunk_order_recv, 259 | switch_chunk_order_send, 260 | switch_link_mapping_recv, 261 | switch_link_mapping_send) 262 | 263 | _add_chunk_order(opt, heuristic, recv_right_after, send_right_after) 264 | 265 | # returns (is_static_val_cor, is_before_cor) 266 | def _get_isbefore(c,o,r,src): 267 | if (c,o,r,src) in self.is_before_set_1: 268 | return True, 1 269 | elif (c,o,r,src) in self.is_before: 270 | return False, self.is_before[(c,o,r,src)] 271 | else: 272 | return True, 0 273 | 274 | # returns (is_static_val_cor, is_together_cor) 275 | def _get_istogether(c,o,r,src): 276 | c1,o1 = minmax(c,o) 277 | if (c1,o1,r,src) in self.is_together_set_1: 278 | return True, 1 279 | elif (c1,o1,r,src) in self.is_together: 280 | return False, self.is_together[(c1,o1,r,src)] 281 | else: 282 | return True, 0 283 | 284 | print("endpoints_rc", endpoints_rc) 285 | # Correctness constraints 286 | self.weighted_terms_to_min = [] 287 | for r in self.collective.ranks(): 288 | src_r = [src for src in self.topology.sources(r)] 289 | links_r = {src: self.topology.link(src,r) for src in src_r} 290 | for c in self.collective.chunks(): 291 | opt.addLConstr(self.start[c,r] <= ST) 292 | if (r,c) in endpoints_rc: 293 | opt.addLConstr(self.start[c,r] == 0) 294 | else: 295 | # Bandwidth constraint 296 | for src in src_r: 297 | for l in range(links_r[src]): 298 | if (c,src,r,l) in self.is_sent_set_1: 299 | opt.addLConstr(self.start[c,r] >= 
self.send[c,src,r,l] + calc_latency(src,r,l,c)) 300 | else: 301 | opt.addLConstr(self.send[c,src,r,l] >= SND) 302 | for l in range(links_r[src], L): 303 | opt.addLConstr(self.send[c,src,r,l] == SND) 304 | recvd_anytime = sum([sum([1 if (c,src,r,l) in self.is_sent_set_1 else 0 for l in range(links_r[src])]) for src in src_r]) 305 | if self.collective.precondition(r, c): 306 | opt.addLConstr(self.start[c,r] <= self.time) 307 | else: 308 | if recvd_anytime == 0 and (r,c) not in endpoints_rc: 309 | print("setting >", c,r) 310 | opt.addLConstr(self.start[c,r] >= self.time + 1) 311 | else: 312 | opt.addLConstr(self.start[c,r] <= self.time) 313 | 314 | c_sources = [] 315 | for src in src_r: 316 | for l in range(links_r[src]): 317 | if (c,src,r,l) in self.is_sent_set_1: 318 | opt.addLConstr(self.start[c,src] <= self.start[c,r]) 319 | # c_sources.append((src,l)) 320 | c_sources.append(src) # NOTE assuming l == 0 always 321 | opt.addLConstr(self.start[c,src] <= self.send[c,src,r,l]) 322 | 323 | for i in range(len(c_sources)): 324 | for j in range(len(c_sources)): 325 | if i!=j: 326 | srci = c_sources[i] 327 | srcj = c_sources[j] 328 | srci1, srcj1 = minmax(srci,srcj) 329 | if (r,c,srci1,srcj1) not in self.is_reduce_before: 330 | self.is_reduce_before[(r,c,srci1,srcj1)] = opt.addVar(vtype=GRB.BINARY) 331 | if (r//num_local_nodes == srci1//num_local_nodes) and (r//num_local_nodes != srcj1//num_local_nodes): 332 | # try to reduce local nodes first (but only try) 333 | self.weighted_terms_to_min.append(-self.is_reduce_before[(r,c,srci1,srcj1)]) 334 | elif (r//num_local_nodes == srcj1//num_local_nodes) and (r//num_local_nodes != srci1//num_local_nodes): 335 | self.weighted_terms_to_min.append(self.is_reduce_before[(r,c,srci1,srcj1)]) 336 | 337 | opt.addGenConstrIndicator(self.is_reduce_before[(r,c,srci1,srcj1)], True, self.send[c,srcj1,r,0] >= self.send[c,srci1,r,0] + calc_latency(srci1,r,l,c)) 338 | opt.addGenConstrIndicator(self.is_reduce_before[(r,c,srci1,srcj1)], False, 
self.send[c,srci1,r,0] >= self.send[c,srcj1,r,0] + calc_latency(srcj1,r,l,c)) 339 | 340 | 341 | # Order sends from same gpu to same gpu 342 | for o in range(c): 343 | for src in src_r: 344 | is_static_cor, is_before_cor = _get_isbefore(c,o,r,src) 345 | is_static_ocr, is_before_ocr = _get_isbefore(o,c,r,src) 346 | is_static_t_ocr, is_together_ocr = _get_istogether(o,c,r,src) 347 | # chunks sent together must have same send and start time 348 | if is_static_t_ocr and is_together_ocr == 1: 349 | for l in range(self.topology.link(src,r)): 350 | if (c,src,r,l) in self.is_sent_set_1 and (o,src,r,l) in self.is_sent_set_1: 351 | opt.addLConstr(self.send[c,src,r,l] == self.send[o,src,r,l]) 352 | opt.addLConstr(self.start[c,r] == self.start[o,r]) 353 | elif not is_static_t_ocr: 354 | for l in range(self.topology.link(src,r)): 355 | if (c,src,r,l) in self.is_sent_set_1 and (o,src,r,l) in self.is_sent_set_1: 356 | opt.addGenConstrIndicator(self.is_together[(o,c,r,src)], True, self.send[c,src,r,l] == self.send[o,src,r,l]) 357 | 358 | 359 | if is_static_cor and is_static_ocr and is_static_t_ocr: 360 | sent_same = any([1 if (c,src,r,l) in self.is_sent_set_1 and (o,src,r,l) in self.is_sent_set_1 else 0 for l in range(L)]) 361 | sent_val = 1 if sent_same else 0 362 | assert is_before_cor + is_before_ocr + is_together_ocr == sent_val, f'{c}, {o}, {r}, {is_before_cor}, {is_before_ocr}, {is_together_ocr}, {sent_val}' 363 | 364 | # Bandwidth constraints based on chunk send times 365 | for l in range(self.topology.link(src,r)): 366 | if (c,src,r,l) in self.is_sent_set_1 and (o,src,r,l) in self.is_sent_set_1: 367 | lat_o = calc_latency(src,r,l,o) 368 | lat_c = calc_latency(src,r,l,c) 369 | 370 | if (c,o,r,src) in self.is_before_set_1: 371 | # print(c,"is_before",o, "for", src, "to", r) 372 | opt.addLConstr(self.send[c,src,r,l] + lat_c <= self.send[o,src,r,l]) 373 | elif (c,o,r,src) in self.is_before: 374 | # print(c,"may be before",o, "for", src, "to", r) 375 | 
opt.addLConstr(self.send[c,src,r,l] + lat_c <= self.send[o,src,r,l] + M*(1-self.is_before[(c,o,r,src)])) 376 | if (o,c,r,src) in self.is_before_set_1: 377 | # print(o,"is_before",c, "for", src, "to", r) 378 | opt.addLConstr(self.send[o,src,r,l] + lat_o <= self.send[c,src,r,l]) 379 | elif (o,c,r,src) in self.is_before: 380 | # print(o,"may be before",c, "for", src, "to", r) 381 | opt.addLConstr(self.send[o,src,r,l] + lat_o <= self.send[c,src,r,l] + M*(1-self.is_before[(o,c,r,src)])) 382 | 383 | # Order receives from a switch 384 | for (c,src,r,l) in self.is_sent_set_1: 385 | if (src,r) in self.topology.switches_involved: 386 | for swt_i, swt_type in self.topology.switches_involved[(src,r)]: 387 | srcs_check = [] 388 | if l == swt_i: 389 | for srcs, dsts, _, _, switch_name in self.topology.switches[swt_i]: 390 | if r in dsts and "in" in switch_name and src in srcs: 391 | srcs_check = srcs 392 | assert len(srcs_check)>0, f'{r} {c} {src} {l} {self.topology.switches[l]}' 393 | break 394 | lat_c = calc_latency(src,r,l,c) 395 | for o in range(c): 396 | for src_o in srcs_check: 397 | if src_o == src: 398 | continue 399 | if (o,src_o,r,l) in self.is_sent_set_1: 400 | if o == c: 401 | assert False 402 | lat_o = calc_latency(src_o,r,l,o) 403 | if (o,c,r,l,src_o,src) in self.recv_first_set_1: 404 | opt.addLConstr(self.send[o,src_o,r,l] + lat_o <= self.send[c,src,r,l]) 405 | elif (c,o,r,l,src,src_o) in self.recv_first_set_1: 406 | opt.addLConstr(self.send[c,src,r,l] + lat_c <= self.send[o,src_o,r,l]) 407 | else: 408 | assert False, f"no-ordering {o}, {c}, {r}, {src}, {src_o}" 409 | assert (o,c,r,l) not in self.recv_first, f'{o},{c},{r},{l}' 410 | self.recv_first[(o,c,r,l)] = opt.addVar(vtype=GRB.BINARY) 411 | opt.addLConstr(self.start[o,r] + lat_c <= self.start[c,r] + M*(1-self.recv_first[(o,c,r,l)])) 412 | opt.addLConstr(self.start[c,r] + lat_o <= self.start[o,r] + M*(self.recv_first[(o,c,r,l)])) 413 | 414 | # Order sends to a switch 415 | for (c,r,dst,l) in 
self.is_sent_set_1: 416 | if (r,dst) in self.topology.switches_involved: 417 | for swt_i, swt_type in self.topology.switches_involved[(r,dst)]: 418 | dsts_check = [] 419 | if l == swt_i: 420 | for srcs, dsts, _, _, switch_name in self.topology.switches[swt_i]: 421 | if r in srcs and "out" in switch_name and dst in dsts: 422 | dsts_check = dsts 423 | assert len(dsts_check)>0, f'{r} {c} {dst} {l} {self.topology.switches[l]}' 424 | break 425 | lat_c = calc_latency(r,dst,l,c) 426 | for o in range(c+1): 427 | for dst_o in dsts_check: 428 | if dst_o == dst: 429 | continue 430 | if (o,r,dst_o,l) in self.is_sent_set_1: 431 | lat_o = calc_latency(r,dst_o,l,o) 432 | if (o,c,r,l,dst_o,dst) in self.send_first_set_1: 433 | opt.addLConstr(self.send[o,r,dst_o,l] + lat_o <= self.send[c,r,dst,l]) 434 | elif (c,o,r,l,dst,dst_o) in self.send_first_set_1: 435 | opt.addLConstr(self.send[c,r,dst,l] + lat_c <= self.send[o,r,dst_o,l]) 436 | else: 437 | assert False 438 | assert (o,c,r,l) not in self.send_first, f'{o},{c},{r},{l}' 439 | self.send_first[(o,c,r,l)] = opt.addVar(vtype=GRB.BINARY) 440 | opt.addLConstr(self.send[o,r,dst_o,l] + lat_o <= self.send[c,r,dst,l] + M*(1-self.send_first[(o,c,r,l)])) 441 | opt.addLConstr(self.send[c,r,dst,l] + lat_c <= self.send[o,r,dst_o,l] + M*(self.send_first[(o,c,r,l)])) 442 | 443 | if prefer_local_reduce_first and len(self.weighted_terms_to_min): 444 | print("Weighted terms will be minimized:") 445 | print(self.weighted_terms_to_min) 446 | opt.setObjective(self.time + 0.001 * quicksum([term for term in self.weighted_terms_to_min]), GRB.MINIMIZE) 447 | else: 448 | opt.setObjective(self.time, GRB.MINIMIZE) 449 | 450 | def optimize_reversed(self, chunk_order=None, time_recv=None, 451 | switch_chunk_recv=None, switch_time_recv=None, switch_chunk_send=None, switch_time_send=None, 452 | nic_chunk_recv=None, nic_time_recv=None, nic_chunk_send=None, nic_time_send=None, 453 | switch_link_mapping_recv=None, switch_link_mapping_send=None, paths=None, 
prefer_local_reduce_first=False): 454 | 455 | C = self.collective.num_chunks 456 | R = self.collective.num_nodes 457 | L = self.topology.L 458 | 459 | from time import time 460 | endpoints_rc = [] 461 | for c in paths: 462 | for path in paths[c]: 463 | last_transfer_r = path[0][1] 464 | endpoints_rc.append((last_transfer_r,c)) 465 | self.topology.reverse_links() 466 | 467 | start_time = time() 468 | opt = Model('taccl_{}_{}'.format(self.topology.name, self.collective.name)) 469 | 470 | # call to _encode swaps the order of switch_link_mapping_send and switch_link_mapping_recv 471 | self._encode(opt, chunk_order, time_recv, 472 | switch_chunk_recv, switch_time_recv, switch_chunk_send, switch_time_send, 473 | nic_chunk_recv, nic_time_recv, nic_chunk_send, nic_time_send, 474 | switch_link_mapping_send, switch_link_mapping_recv, endpoints_rc, self.route_sketch.hyperparameters.heuristic, prefer_local_reduce_first) 475 | opt.optimize() 476 | end_time = time() 477 | print("strict time (encode+solve)", end_time-start_time, flush=True) 478 | 479 | if opt.status == GRB.INFEASIBLE: 480 | opt.computeIIS() 481 | opt.write("model.ilp") 482 | raise ValueError("Infeasible model") 483 | 484 | send_dict = defaultdict(list) 485 | SCALE_TIME = 10 486 | 487 | model_str = "" 488 | other_model_str = "" 489 | for c in range(C): 490 | for r in range(R): 491 | if self.start[c,r].X <= self.time.X + 0.005: 492 | model_str += f'start[{c},{r}]={self.start[c,r].X}\n' 493 | recv_times = defaultdict(list) 494 | chunk_path = [defaultdict(list) for c in range(C)] 495 | for src in range(R): 496 | for r in self.topology.destinations(src): 497 | for l in range(L): 498 | for c_np in chunk_order[r][src][l]: 499 | c = int(c_np) 500 | assert (c,src,r,l) in self.is_sent_set_1 501 | # model_str += f'{c}: {src} --{l}--> {r} t={self.send[c,src,r,l].X}\n' 502 | t = int(SCALE_TIME*self.send[c,src,r,l].X + 0.0001) 503 | transfer_str = f'{c}: {src} --{l}--> {r} t={self.send[c,src,r,l].X}\n' 504 | 
recv_times[t].append(transfer_str) 505 | chunk_path[c][t].append(transfer_str) 506 | send_dict[t].append([c,src,r,t,l,'rrc']) 507 | for c_np in range(C): 508 | c = int(c_np) 509 | if c not in chunk_order[r][src][l]: 510 | assert (c,src,r,l) not in self.is_sent_set_1 511 | for tval in sorted(recv_times.keys()): 512 | for strval in recv_times[tval]: 513 | model_str += strval 514 | for c in range(C): 515 | for tval in sorted(chunk_path[c].keys()): 516 | for strval in chunk_path[c][tval]: 517 | other_model_str += strval 518 | for c in range(C): 519 | for o in range(c): 520 | for r in range(R): 521 | for src in self.topology.sources(r): 522 | if (o,c,r,src) in self.is_together: 523 | if self.is_together[(o,c,r,src)].X >= 0.995: 524 | print(f'is_together[{o},{c},{r},{src}] = {self.is_together[(o,c,r,src)].X}') 525 | model_str += f'({c},{o},{r},{src})\n' 526 | elif (o,c,r,src) in self.is_together_set_1: 527 | model_str += f'({c},{o},{r},{src}) set\n' 528 | print(f'({c},{o},{r},{src}) set together') 529 | if (c,o,r,src) in self.is_before and self.is_before[(c,o,r,src)].X >= 0.995: 530 | print(f'is_before[{c},{o},{r},{src}]') 531 | if (o,c,r,src) in self.is_before and self.is_before[(o,c,r,src)].X >= 0.995: 532 | print(f'is_before[{o},{c},{r},{src}]') 533 | 534 | print(model_str) 535 | print("Chunk path:") 536 | print(other_model_str) 537 | return send_dict 538 | 539 | 540 | def build_allreduce(self, reduce_coll, send_dict_redscat, send_dict_allgather, ts): 541 | import math 542 | 543 | C = self.collective.num_chunks 544 | R = self.collective.num_nodes 545 | L = self.topology.L 546 | print(R,C,L) 547 | 548 | SCALE_TIME = 10 549 | 550 | do_redscat = True 551 | do_allgather = True 552 | 553 | # assert len(ts) 554 | # send_dict_allgather = np.load(f"send_dict_{ts}.npy", allow_pickle=True).item() 555 | steps=[] 556 | send_times_redscat = sorted(send_dict_redscat.keys()) 557 | print("senddicts:") 558 | print("send_dict_redscat:", send_dict_redscat) 559 | 
print("send_dict_allgather:", send_dict_allgather) 560 | tmax = send_times_redscat[-1] 561 | print("tmax", tmax) 562 | shifted_send_dict_allgather = defaultdict(list) 563 | for t in send_dict_allgather: 564 | for (c,src,r,t_,l) in send_dict_allgather[t]: 565 | t_shifted = tmax + t + SCALE_TIME 566 | shifted_send_dict_allgather[t_shifted].append([c,src,r,t_shifted,l,None]) # reverse 567 | 568 | if do_redscat and do_allgather: 569 | send_dict = send_dict_redscat.copy() 570 | send_dict.update(shifted_send_dict_allgather) 571 | elif do_redscat: 572 | send_dict = send_dict_redscat.copy() 573 | elif do_allgather: 574 | send_dict = shifted_send_dict_allgather.copy() 575 | self.topology.reverse_links() 576 | print("send_dict:", send_dict) 577 | send_times = sorted(send_dict.keys()) 578 | 579 | i = 0 580 | while(i < len(send_times)): 581 | num_sends = [[0 for _ in range(R)] for _ in range(R)] 582 | j = i + 1 583 | while j < len(send_times): 584 | to_break = False 585 | t_end = send_times[j] 586 | for (c,src,r,_,_,redop) in send_dict[t_end]: 587 | for t in range(i,j): 588 | for (ci,srci,ri,_,_,redopi) in send_dict[send_times[t]]: 589 | if (c == ci and src == ri) or (c == ci and r == ri and redop is not None and redopi is not None): 590 | to_break = True 591 | break 592 | if to_break: 593 | break 594 | if to_break: 595 | break 596 | if to_break: 597 | break 598 | j = j + 1 599 | sends = [] 600 | for k in range(i,j): 601 | sends.extend(send_dict[send_times[k]]) 602 | print(sends) 603 | num_sends = [[[0 for _ in range(L)] for _ in range(R)] for _ in range(R)] 604 | for (c,src,r,_,l,_) in sends: 605 | num_sends[r][src][l] = num_sends[r][src][l] + 1 606 | rounds = 0 607 | for srcs, dsts, bw, l, name in self.topology.real_bandwidth_constraints(): 608 | util = 0 609 | for dst in dsts: 610 | for src in srcs: 611 | util += num_sends[dst][src][l] 612 | if rounds <= util * bw * SCALE_TIME: 613 | rounds = math.ceil(util * bw * SCALE_TIME) 614 | step = Step(rounds, sorted(sends, 
key=lambda x: x[3])) 615 | print("STEP ", step) 616 | steps.append(step) 617 | i = j 618 | 619 | if do_allgather and do_redscat: 620 | instance = Instance( 621 | steps=len(steps), 622 | extra_rounds=0, 623 | chunks=R*self.chunkup, 624 | ) 625 | elif do_redscat: 626 | instance = Instance( 627 | steps=len(steps), 628 | extra_rounds=0, 629 | chunks=self.chunkup, 630 | ) 631 | for step in steps: 632 | print(step) 633 | elif do_allgather: 634 | instance = Instance( 635 | steps=len(steps), 636 | extra_rounds=0, 637 | chunks=self.chunkup, 638 | ) 639 | 640 | if do_redscat and do_allgather: 641 | soltype = f"{ts}-allreduce" 642 | elif do_redscat: 643 | soltype = f"{ts}-redscat" 644 | elif do_allgather: 645 | soltype = f"{ts}-allgather" 646 | 647 | from time import time 648 | timestamp = int(time()) 649 | np.save(f'send_dict_allred_{timestamp}.npy', send_dict) 650 | return Algorithm.make_implementation(reduce_coll, self.topology, instance, steps, cont=True, suffix=f'-gurobisol-{soltype}-{timestamp}') 651 | -------------------------------------------------------------------------------- /taccl/routing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
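The batching loop in `build_allreduce` above merges consecutive send timestamps into a single schedule step until a dependency conflict appears: either a chunk is forwarded from a rank that only receives it within the same step, or two reductions of the same chunk land on the same rank. A condensed pure-Python sketch of that conflict test (function and tuple layout are illustrative, not part of the repo's API):

```python
# Sketch of the conflict rule build_allreduce uses when merging
# timestamps into one step. A send is (chunk, src, dst, redop);
# redop is not None for reduce-scatter transfers, None for allgather.
def conflicts(send, earlier):
    c, src, r, redop = send
    ci, srci, ri, redopi = earlier
    # Conflict if the new send forwards a chunk from the destination of
    # an earlier send in this step, or both reduce the same chunk onto
    # the same rank.
    return (c == ci and src == ri) or \
           (c == ci and r == ri and redop is not None and redopi is not None)

# chunk 0 is sent 1 -> 2, then forwarded 2 -> 3: must go in a later step
assert conflicts((0, 2, 3, None), (0, 1, 2, None))
# sends of different chunks can share a step
assert not conflicts((1, 0, 2, None), (0, 1, 2, None))
# two reductions of chunk 0 onto rank 2 cannot share a step
assert conflicts((0, 3, 2, 'rrc'), (0, 1, 2, 'rrc'))
```

The step's round count is then sized by the busiest link, mirroring the `real_bandwidth_constraints` loop above.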
3 | 4 | from collections import defaultdict 5 | from gurobipy import GRB, Model, quicksum 6 | from taccl.algorithm import * 7 | from taccl.shortest_path_sets import * 8 | from taccl.topologies.route_sketch import * 9 | from taccl.utils import * 10 | from time import time 11 | 12 | verbose = True 13 | 14 | class TACCLRouting: 15 | def __init__(self, topology, route_sketch, collective): 16 | self.topology = topology 17 | self.route_sketch = route_sketch 18 | self.collective = collective 19 | self.chunkup = self.route_sketch.hyperparameters.chunkup 20 | 21 | def latency(self, src, dst, l): 22 | return self.topology.get_invbw(src,dst) 23 | 24 | def _encode(self, opt): 25 | # print("topology", self.topology.name) 26 | # print(self.topology.links) 27 | # print(self.topology.invbws) 28 | C = self.collective.num_chunks 29 | R = self.collective.num_nodes 30 | M = 10000000 # big-M for maximum self.time between sends 31 | ST = 5000000 # self.time for unstarted starts 32 | SND = 10000000 # self.time for unsent sends 33 | L = self.topology.L 34 | 35 | opt.Params.MIPFocus = 1 36 | opt.Params.Method = 2 37 | opt.Params.NumericFocus = 3 38 | opt.Params.Threads = 12 39 | opt.Params.MIPGap = 1e-9 40 | opt.Params.TimeLimit = 1800 41 | opt.Params.IntegralityFocus = 1 42 | opt.Params.IntFeasTol = 1e-9 43 | opt.Params.FeasibilityTol = 1e-9 44 | 45 | mu = 0.01 46 | 47 | self.send = opt.addVars(C, R, R, L, name="send", vtype=GRB.CONTINUOUS, lb=0.0) 48 | self.start = opt.addVars(C, R, name="start", vtype=GRB.CONTINUOUS, lb=0.0) 49 | self.time = opt.addVar(name="time", vtype=GRB.CONTINUOUS) 50 | opt.addLConstr(self.time <= ST-1) 51 | self.is_sent = opt.addVars(C,R,R,L, name="is_sent", vtype=GRB.BINARY) 52 | if self.route_sketch.intranode.strategy == "minmax": 53 | self.max_link_util = opt.addVar(name="max_link_util", vtype=GRB.CONTINUOUS, lb=0.0) 54 | if self.route_sketch.intranode.strategy == "maxmin": 55 | self.min_link_util = opt.addVar(name="min_link_util", vtype=GRB.CONTINUOUS, lb=0.0) 
56 | 57 | num_local_nodes = self.topology.num_nodes() // self.topology.copies 58 | 59 | opt.ModelSense = GRB.MINIMIZE 60 | 61 | # Don't send chunks over connections which are not linked 62 | for src in range(R): 63 | for dst in range(R): 64 | if dst not in self.topology.destinations(src): 65 | for c in self.collective.chunks(): 66 | for l in range(L): 67 | opt.addLConstr(self.is_sent[c,src,dst,l] == 0) 68 | opt.addLConstr(self.send[c,src,dst,l] == SND) 69 | else: 70 | num_links = self.topology.link(src,dst) 71 | for c in self.collective.chunks(): 72 | for l in range(num_links, L): 73 | opt.addLConstr(self.is_sent[c,src,dst,l] == 0) 74 | opt.addLConstr(self.send[c,src,dst,l] == SND) 75 | 76 | 77 | num_local_nodes = self.topology.num_nodes() // self.topology.copies 78 | for c in self.collective.chunks(): 79 | for r in self.collective.ranks(): 80 | opt.addLConstr(self.start[c,r] <= ST) 81 | # Fixing to only spsets will reduce chances for contiguity, but it is fine 82 | # Don't send to r if it is not in spset of c 83 | for src in self.topology.sources(r): 84 | if (r not in self.spsets[c]) or (src not in self.spsets[c]): 85 | for l in range(L): 86 | opt.addLConstr(self.send[c,src,r,l] == SND) 87 | opt.addLConstr(self.is_sent[c,src,r,l] == 0) 88 | if r not in self.spsets[c]: 89 | opt.addLConstr(self.start[c,r] == ST) 90 | continue 91 | 92 | if self.collective.precondition(r, c): 93 | # Have chunks start on their starting ranks before the first step 94 | opt.addLConstr(self.start[c,r] == 0) 95 | for src in self.topology.sources(r): 96 | for l in range(self.topology.link(src,r)): 97 | opt.addLConstr(self.is_sent[c,src,r,l] == 0) 98 | else: 99 | for src in self.topology.sources(r): 100 | for l in range(self.topology.link(src,r)): 101 | opt.addGenConstrIndicator(self.is_sent[c,src,r,l], True, self.start[c,r] == self.send[c,src,r,l] + self.latency(src,r,l)) 102 | opt.addGenConstrIndicator(self.is_sent[c,src,r,l], False, self.send[c,src,r,l] == SND) 103 | 104 | if 
self.collective.postcondition(r, c): 105 | opt.addLConstr(quicksum(quicksum(self.is_sent[c,src,r,l] for l in range(L)) for src in self.topology.sources(r)) == 1, name=f'post_{r}_{c}') 106 | # opt.addLConstr(quicksum(quicksum(self.is_sent[c,src,r,l] for l in range(L)) for src in range(R)) == 1) 107 | opt.addLConstr(self.start[c,r] <= self.time) 108 | else: 109 | opt.addLConstr(quicksum(quicksum(self.is_sent[c,src,r,l] for l in range(L)) for src in range(R)) <= 1, name=f'non_post_{r}_{c}') 110 | opt.addLConstr(self.start[c,r] <= self.time + M*(1-quicksum(quicksum(self.is_sent[c,src,r,l] for l in range(L)) for src in range(R)))) 111 | opt.addLConstr(self.start[c,r] >= self.time + 1 - M*(quicksum(quicksum(self.is_sent[c,src,r,l] for l in range(L)) for src in range(R)))) 112 | 113 | for src in self.topology.sources(r): 114 | for l in range(self.topology.link(src,r)): 115 | opt.addLConstr(self.start[c,src] <= self.send[c,src,r,l]) 116 | 117 | # Count total switch send and switch recv in bounding the time of algo 118 | for l, switches in enumerate(self.topology.switches): 119 | for srcs, dsts, _, swtbw, switch_name in switches: 120 | if "in" in switch_name: 121 | for dst in dsts: 122 | opt.addLConstr(self.time >= quicksum(quicksum(swtbw*self.is_sent[c,srci,dst,l] for c in range(C)) for srci in srcs), name=f'switchin_{dst}_{l}') 123 | if self.route_sketch.intranode.strategy == "minmax": 124 | opt.addLConstr(self.max_link_util >= quicksum(quicksum(swtbw*self.is_sent[c,srci,dst,l] for c in range(C)) for srci in srcs), name=f'Mx_switchin_{dst}_{l}') 125 | if self.route_sketch.intranode.strategy == "maxmin": 126 | opt.addLConstr(self.min_link_util <= quicksum(quicksum(swtbw*self.is_sent[c,srci,dst,l] for c in range(C)) for srci in srcs), name=f'mx_switchin_{dst}_{l}') 127 | if "out" in switch_name: 128 | for src in srcs: 129 | opt.addLConstr(self.time >= quicksum(quicksum(swtbw*self.is_sent[c,src,dsti,l] for c in range(C)) for dsti in dsts), name=f'switchout_{src}_{l}') 130 | 
if self.route_sketch.intranode.strategy == "minmax": 131 | opt.addLConstr(self.max_link_util >= quicksum(quicksum(swtbw*self.is_sent[c,src,dsti,l] for c in range(C)) for dsti in dsts), name=f'Mx_switchout_{src}_{l}') 132 | if self.route_sketch.intranode.strategy == "maxmin": 133 | opt.addLConstr(self.min_link_util <= quicksum(quicksum(swtbw*self.is_sent[c,src,dsti,l] for c in range(C)) for dsti in dsts), name=f'mx_switchout_{src}_{l}') 134 | for c in self.collective.chunks(): 135 | if src in self.spsets[c]: 136 | for dstj in dsts: 137 | opt.addGenConstrIndicator(self.is_sent[c,src,dstj,l], True, self.time >= self.start[c,src] + swtbw * quicksum(self.is_sent[c,src,dsti,l] for dsti in dsts)) 138 | 139 | # Count total link transfer in bounding the time of algo 140 | for r in self.collective.ranks(): 141 | for src in self.topology.sources(r): 142 | for l in range(self.topology.link(src,r)): 143 | opt.addLConstr(self.time >= quicksum(self.latency(src,r,l)*self.is_sent[c,src,r,l] for c in range(C))) 144 | if self.route_sketch.intranode.strategy == "minmax": 145 | opt.addLConstr(self.max_link_util >= quicksum(self.latency(src,r,l)*self.is_sent[c,src,r,l] for c in range(C))) 146 | if self.route_sketch.intranode.strategy == "maxmin": 147 | opt.addLConstr(self.min_link_util <= quicksum(self.latency(src,r,l)*self.is_sent[c,src,r,l] for c in range(C))) 148 | 149 | if isinstance(self.route_sketch.intranode, IntraNode_Switch): 150 | self._add_min_max_unique(opt, num_local_nodes, mu, L) 151 | 152 | if self.topology.copies > 1: 153 | self._add_relay_relaxation(opt, SND) 154 | if self.route_sketch.internode.enforce_ordering: 155 | self._enforce_ordering(opt) 156 | 157 | self._add_symmetry(opt, L) 158 | 159 | if isinstance(self.route_sketch.intranode, IntraNode_Switch): 160 | if self.route_sketch.intranode.switch_hyperedge_strategy[0] == "uc-min": 161 | print("--- minUniqueSends") 162 | opt.setObjective(self.time + self.mu * self.unique_links, GRB.MINIMIZE) 163 | elif 
self.route_sketch.intranode.switch_hyperedge_strategy[0] == "uc-max": 164 | print("--- maxUniqueSends") 165 | opt.setObjective(self.time - self.mu * self.unique_links, GRB.MINIMIZE) 166 | else: 167 | pass 168 | elif self.route_sketch.intranode.strategy == "minmax": 169 | print("minimizing maximum link utilization") 170 | opt.setObjective(self.time + self.mu * self.max_link_util, GRB.MINIMIZE) 171 | elif self.route_sketch.intranode.strategy == "maxmin": 172 | print("maximizing minimum link utilization") # To do better load balancing 173 | opt.setObjective(self.time - self.mu * self.min_link_util, GRB.MINIMIZE) 174 | else: 175 | opt.setObjective(self.time, GRB.MINIMIZE) 176 | 177 | def _enforce_ordering(self, opt): 178 | print("--- _enforce_ordering") 179 | assert self.route_sketch.internode.gpus_to_sender_rev_map is not None 180 | # Send the chunks of inter-node sender first and then the chunks of other gpus that are mapped to the inter-node sender 181 | sender_to_gpu = self.route_sketch.internode.gpus_to_sender_rev_map 182 | for sender in sender_to_gpu: 183 | for cp in range(self.topology.copies): 184 | src = self.topology.base_gpus[cp] + int(sender) 185 | all_chunks = [c for gpu in sender_to_gpu[sender] for c in self.collective.pre_chunk(self.topology.base_gpus[cp] + gpu)] 186 | sender_chunks = [c for c in self.collective.pre_chunk(src)] 187 | for r in self.topology.destinations(src): 188 | if self.topology.gpu_to_node(r) != self.topology.gpu_to_node(src): 189 | for c in all_chunks: 190 | if c not in sender_chunks: 191 | for c_sender in sender_chunks: 192 | for l in range(self.topology.link(src,r)): 193 | opt.addGenConstrIndicator(self.is_sent[c_sender,src,r,l], True, self.send[c,src,r,l] >= self.send[c_sender,src,r,l] + self.latency(src,r,l)) 194 | 195 | def sym_rank(self, r, i, sym_offset, sym_size): 196 | return (r % sym_size + i * sym_offset) % sym_size + (r // sym_size) * sym_size 197 | 198 | def sym_chunk(self, c, i, sym_offset, sym_size): 199 | c_offset = c 
% self.chunkup 200 | # This method of find symmetric chunk works for Alltoall and Allgather 201 | # For Alltoall, there is a single pre and post rank for each chunk 202 | # For Allgather, there is a single pre rank for each chunk, thus still allowing a quick match 203 | for r_pre in self.collective.pre_rank(c): 204 | break 205 | for r_post in self.collective.post_rank(c): 206 | break 207 | r_pre_sym = self.sym_rank(r_pre, i, sym_offset, sym_size) 208 | r_post_sym = self.sym_rank(r_post, i, sym_offset, sym_size) 209 | c_sym = -1 210 | for c_opt in self.collective.post_chunk(r_post_sym): 211 | if self.collective.precondition(r_pre_sym, c_opt) and c_opt % self.chunkup == c_offset: 212 | assert c_sym == -1 213 | c_sym = c_opt 214 | return c_sym, r_pre_sym, r_post_sym 215 | 216 | def _add_symmetry(self, opt, L): 217 | print("--- _add_symmetry") 218 | num_nodes = self.topology.num_nodes() 219 | count = len(self.route_sketch.symmetry.offsets) 220 | for (sym_offset, sym_size) in self.route_sketch.symmetry.offsets: 221 | already_added = [] 222 | for c in self.collective.chunks(): 223 | c_sym, r_pre_sym, r_post_sym = self.sym_chunk(c, 1, sym_offset, sym_size) 224 | if c_sym == -1: 225 | assert False, "Collective is not symmetric" 226 | pair_c = (c, c_sym) if c <= c_sym else (c_sym, c) 227 | if pair_c in already_added: 228 | continue 229 | already_added.append(pair_c) 230 | for r in range(num_nodes): 231 | r_sym = self.sym_rank(r, 1, sym_offset, sym_size) 232 | for src in range(num_nodes): 233 | if (r // sym_size == src // sym_size): 234 | src_sym = self.sym_rank(src, 1, sym_offset, sym_size) 235 | for l in range(L): 236 | opt.addLConstr(self.send[c,src,r,l] == self.send[c_sym, src_sym, r_sym, l]) 237 | opt.addLConstr(self.is_sent[c,src,r,l] == self.is_sent[c_sym, src_sym, r_sym, l], name=f'sym_{c}_{src}_{r}_{src_sym}_{r_sym}_{l}') 238 | opt.addLConstr(self.start[c,r] == self.start[c_sym, r_sym]) 239 | 240 | 241 | def _add_relay_relaxation(self, opt, SND): 242 | print("--- 
_add_relay_relaxation_new") 243 | num_local_nodes = self.topology.num_nodes() // self.topology.copies 244 | chunk_to_sender_map = defaultdict(list) 245 | if self.route_sketch.internode.gpus_to_sender_rev_map is not None: 246 | for sender in self.route_sketch.internode.gpus_to_sender_rev_map: 247 | for gpu_src in self.route_sketch.internode.gpus_to_sender_rev_map[sender]: 248 | for i in range(self.topology.copies): 249 | node_sender = int(sender) + self.topology.base_gpus[i] 250 | node_src = gpu_src + self.topology.base_gpus[i] 251 | for c in self.collective.pre_chunk(node_src): 252 | chunk_to_sender_map[c].append(node_sender) 253 | 254 | for (strategy, nnodes, group_size) in zip(self.route_sketch.multinode.strategy, self.route_sketch.multinode.nnodes, self.route_sketch.multinode.group_size): 255 | if strategy == "round-robin" or strategy == "relay": 256 | all_gpus = defaultdict() 257 | num_groups = self.topology.copies // group_size 258 | for base_n in range(0, group_size*num_groups, nnodes): 259 | # print("base_n", base_n, "nnodes", nnodes, "group_size", group_size, "num_groups", num_groups) 260 | all_gpus[base_n] = [g for g in range(self.topology.base_gpus[base_n], self.topology.base_gpus[base_n + nnodes])] 261 | for c in self.collective.chunks(): 262 | pair_set = defaultdict(set) 263 | for r1 in self.collective.pre_rank(c): 264 | for r2 in self.collective.post_rank(c): 265 | n1 = self.topology.gpu_to_node(r1) 266 | n2 = self.topology.gpu_to_node(r2) 267 | base_n1 = (n1 // nnodes) * nnodes 268 | base_n2 = (n2 // nnodes) * nnodes 269 | if (base_n1 != base_n2) and (n1 // group_size == n2 // group_size): 270 | senders = all_gpus[base_n1] 271 | receivers = all_gpus[base_n2] 272 | if self.route_sketch.internode.gpus_to_sender_rev_map is not None: 273 | assert c in chunk_to_sender_map 274 | assert len(set(chunk_to_sender_map[c]) & set(senders)) == len(set(chunk_to_sender_map[c])) 275 | senders = [g for g in chunk_to_sender_map[c]] 276 | # remove senders and receivers 
that are not in spsets 277 | senders = list(filter(lambda x: x in self.spsets[c], senders)) 278 | receivers = list(filter(lambda x: x in self.spsets[c], receivers)) 279 | for s in senders: 280 | for r in receivers: 281 | if self.topology.link(s,r) > 0: 282 | pair_set[(base_n1,base_n2)].add((s,r)) 283 | # print("pair set", c, pair_set) 284 | for (bn1, bn2) in pair_set: 285 | opt.addLConstr(quicksum(self.is_sent[c,src,r,l] for (src,r) in pair_set[(bn1,bn2)] for l in range(self.topology.link(src,r))) >= 1) 286 | for src in all_gpus[bn1]: 287 | for r in all_gpus[bn2]: 288 | if (src,r) not in pair_set[(bn1,bn2)]: 289 | for l in range(self.topology.link(src,r)): 290 | opt.addLConstr(self.send[c,src,r,l] == SND) 291 | opt.addLConstr(self.is_sent[c,src,r,l] == 0, name=f'relay_notSend_{c}_{src}_{r}_{l}') 292 | assert (r,src) not in pair_set[(bn1,bn2)] 293 | for l in range(self.topology.link(r,src)): 294 | opt.addLConstr(self.send[c,r,src,l] == SND) 295 | opt.addLConstr(self.is_sent[c,r,src,l] == 0, name=f'relay_notSend_{c}_{r}_{src}_{l}') 296 | # If c doesn't need to be sent outside the node, then set all internode transfers for that chunk to 0 297 | if len(pair_set) == 0: 298 | for r1 in self.collective.pre_rank(c): 299 | n1 = self.topology.gpu_to_node(r1) 300 | base_n1 = (n1 // nnodes) * nnodes 301 | for src in self.collective.ranks(): 302 | for r in self.collective.ranks(): 303 | if (src not in all_gpus[base_n1]) or (r not in all_gpus[base_n1]): 304 | for l in range(self.topology.link(src,r)): 305 | opt.addLConstr(self.send[c,src,r,l] == SND) 306 | opt.addLConstr(self.is_sent[c,src,r,l] == 0, name=f'relay_sendNotNeeded_{c}_{src}_{r}_{l}') 307 | 308 | break 309 | 310 | elif strategy == "ring": 311 | assert False, "Ring strategy is not yet implemented" 312 | else: 313 | assert False, "strategy is not defined" 314 | 315 | def _add_min_max_unique(self, opt, num_local_nodes, mu, L): 316 | print("--- _add_min_max_unique") 317 | # print("SEND_AT_ALL") 318 | self.send_at_all = 
opt.addVars(num_local_nodes,num_local_nodes,L, name="send_at_all", vtype=GRB.BINARY) 319 | 320 | for r in range(num_local_nodes): 321 | for src in self.topology.sources(r): 322 | if src < num_local_nodes: 323 | for l in range(L): 324 | for c in self.collective.chunks(): 325 | opt.addLConstr(self.send_at_all[src,r,l] >= self.is_sent[c,src,r,l]) 326 | opt.addLConstr(self.send_at_all[src,r,l] <= quicksum(self.is_sent[c,src,r,l] for c in self.collective.chunks())) 327 | # print("mu", mu) 328 | self.mu = mu 329 | self.unique_links = quicksum(self.send_at_all[src,r,l] for l in range(L) for r in range(num_local_nodes) for src in self.topology.sources(r) if src < num_local_nodes) 330 | 331 | def optimize(self, distribute_over_links): 332 | import pickle as pkl 333 | heuristic = self.route_sketch.hyperparameters.heuristic 334 | # print("HEURISTIC:", heuristic) 335 | # print("finding shortest path sets") 336 | self.spsets = shortest_path_sets(self.topology, self.collective) 337 | 338 | # print(self.spsets) 339 | # print("found shortest path sets") 340 | instance_name = 'sccl_{}_{}_gurobiSimple'.format(self.topology.name, self.collective.name) 341 | 342 | C = self.collective.num_chunks 343 | R = self.collective.num_nodes 344 | L = self.topology.L 345 | 346 | start_time = time() 347 | opt = Model(instance_name) 348 | self._encode(opt) 349 | opt.optimize() 350 | end_time = time() 351 | print("simple time (encode+solve)", end_time-start_time, flush=True) 352 | 353 | opt.write(f'model_{instance_name}.lp') 354 | if opt.status == GRB.INFEASIBLE: 355 | opt.computeIIS() 356 | opt.write(f'model_{instance_name}.ilp') 357 | raise ValueError("Infeasible model") 358 | 359 | 360 | num_sols = 1 361 | for sol_i in range(num_sols): 362 | opt.Params.SolutionNumber = sol_i 363 | time_recv = [[[[] for l in range(L)] for src in range(R)] for r in range(R)] 364 | chunk_recv = [[[[] for l in range(L)] for src in range(R)] for r in range(R)] 365 | time_send = [[[[] for l in range(L)] for src in 
range(R)] for r in range(R)] 366 | chunk_send = [[[[] for l in range(L)] for src in range(R)] for r in range(R)] 367 | 368 | model_str = "" 369 | for c in range(C): 370 | for r in range(R): 371 | if self.start[c,r].Xn <= self.time.Xn + 0.005: 372 | model_str += f'start[{c},{r}]={self.start[c,r].Xn}\n' 373 | dist_link_heuristic = [3,5,8,9,10,11,13] # Distribute chunk sends if there are multiple links connecting src to r 374 | if distribute_over_links: 375 | assert heuristic in dist_link_heuristic 376 | for c in range(C): 377 | sratch_str = defaultdict(list) 378 | for r in range(R): 379 | for src in self.topology.sources(r): 380 | for l in range(L): 381 | if self.is_sent[c,src,r,l].Xn >= 0.995: 382 | t_val = self.send[c,src,r,l].Xn 383 | sratch_str[t_val].append(f'{c}: {src} --{l}--> {r} t={self.send[c,src,r,l].Xn}\n') 384 | # model_str += f'{c}: {src} --{l}--> {r} t={self.send[c,src,r,l].Xn}\n' 385 | if distribute_over_links: 386 | chunk_send[src][r][0].append(c) 387 | time_send[src][r][0].append(int(self.send[c,src,r,l].Xn + 0.005)) 388 | chunk_recv[r][src][0].append(c) 389 | time_recv[r][src][0].append(int(self.start[c,r].Xn + 0.005)) 390 | else: 391 | chunk_send[src][r][l].append(c) 392 | time_send[src][r][l].append(int(self.send[c,src,r,l].Xn + 0.005)) 393 | chunk_recv[r][src][l].append(c) 394 | time_recv[r][src][l].append(int(self.start[c,r].Xn + 0.005)) 395 | for tval in sorted(sratch_str.keys()): 396 | for strval in sratch_str[tval]: 397 | model_str += strval 398 | # NOTE: we round the start and send times to integers here. 
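The `int(... + 0.005)` pattern in the loop above guards against MIP solver outputs that are only approximately integral (e.g. `2.9999999913` for an intended 3), which plain truncation would round the wrong way. A minimal sketch of the rounding convention (the `EPS` name is illustrative):

```python
# Sketch: why send/start times are rounded via int(x + eps) above.
# Solvers return near-integral floats; plain int() truncation would
# turn 2.9999999913 into 2 and corrupt the recovered schedule.
EPS = 0.005

def round_time(x):
    return int(x + EPS)

assert int(2.9999999913) == 2          # naive truncation is off by one
assert round_time(2.9999999913) == 3   # epsilon absorbs solver noise
assert round_time(3.0000000087) == 3
```

The same tolerance shows up as the `>= 0.995` test on binary variables, which treats any value within solver feasibility tolerance of 1 as set.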
399 | # Would be good to have integral latencies for the path encoding 400 | # print(model_str) 401 | time_new = int(time()) 402 | print(f"Saving cs_ts_cr_tr_simple_{time_new}") 403 | with open(f'cs_ts_cr_tr_simple_{time_new}.pkl', 'wb') as f: 404 | pkl.dump([chunk_send, time_send, chunk_recv, time_recv], f) 405 | 406 | return chunk_send, time_send, chunk_recv, time_recv 407 | 408 | def check_heuristic(self, topology, route_sketch, collective, ts_heur): 409 | import pickle as pkl 410 | assert ts_heur is not None 411 | print(f"Checking sol obtained by heuristic {route_sketch.hyperparameters.heuristic} ts={ts_heur}") 412 | with open(f'cs_ts_cr_tr_simple_{ts_heur}.pkl', 'rb') as f: 413 | chunk_send, time_send, chunk_recv, time_recv = pkl.load(f) 414 | -------------------------------------------------------------------------------- /taccl/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from taccl.algorithm import * 5 | from taccl.instance import * 6 | from taccl.shortest_path_sets import * 7 | from gurobipy import GRB, Model, quicksum, abs_, and_ 8 | from taccl.utils import * 9 | import numpy as np 10 | 11 | class TACCLScheduler(object): 12 | def __init__(self, topology, route_sketch, collective): 13 | self.topology = topology 14 | self.route_sketch = route_sketch 15 | self.collective = collective 16 | 17 | # Don't care about relay relaxation - gurobi simple fixes that 18 | def _is_relay_link(self,r,dst): 19 | if self.topology.gpu_to_node(r) != self.topology.gpu_to_node(dst): 20 | return True 21 | return False 22 | 23 | def _encode(self, opt, chunk_order, chunk_time, 24 | switch_chunk_order_recv, switch_chunk_time_recv, switch_chunk_order_send, switch_chunk_time_send, 25 | nic_chunk_order_recv, nic_chunk_time_recv, nic_chunk_order_send, nic_chunk_time_send, 26 | switch_link_mapping_recv=None, switch_link_mapping_send=None, extra_heuristic=True): 27 | 28 | C = self.collective.num_chunks 29 | R = self.collective.num_nodes 30 | L = self.topology.L 31 | heuristic = self.route_sketch.hyperparameters.heuristic 32 | smallM = 10 33 | M = 10000000 # big-M for maximum self.time between sends 34 | ST = 500000 # self.time for unsent sends and unstarted starts 35 | SND = 1000000 # self.time for unsent sends and unstarted starts 36 | opt.Params.Threads = 1 37 | opt.Params.IntegralityFocus = 1 38 | opt.Params.IntFeasTol = 1e-9 39 | opt.Params.FeasibilityTol = 1e-9 40 | opt.Params.TimeLimit = 1800 41 | 42 | self.spsets = shortest_path_sets(self.topology, self.collective) 43 | num_local_nodes = R // self.topology.copies 44 | 45 | self.is_sent_set_1 = set() 46 | self.is_before_set_1 = set() # Fixed ordering of chunks over a link 47 | self.is_together_set_0 = set() # Fix if chunks are received together on a GPU 48 | self.is_together_set_1 = set() 49 | self.recv_first_set_1 = set() # Fixed ordering between recvs on a switch 50 | self.nic_recv_first_set_1 = 
set() # Fixed ordering between recvs on a NIC 51 | self.send_first_set_1 = set() # Fixed ordering between sends on a switch 52 | self.nic_send_first_set_1 = set() # Fixed ordering between sends on a NIC 53 | 54 | self.is_before = {} # (c,o,r): c is received on GPU r before o, but from same source 55 | self.is_together = {} # (c,o,r): c dst only if r -> dst is an IB 73 | def _should_try_together(r,dst,c,o): 74 | assert r != dst 75 | if self._is_relay_link(r,dst): 76 | return True 77 | return False 78 | 79 | # Can fix contiguous sends if reqd 80 | def _should_fix_together(r,dst,l,c,o): 81 | return False 82 | if not (isinstance(self.topology, DistributedTopology) and self.topology.m_top == MachineTopology.RELAYED): 83 | return False 84 | for rc in self.collective.pre_on(c): 85 | r1 = rc 86 | for ro in self.collective.pre_on(o): 87 | r2 = ro 88 | assert r != dst 89 | if self._is_relay_link(r,dst): 90 | if self.topology.bw_dist[rc][r] == self.topology.bw_dist[ro][r]: 91 | return True 92 | return False 93 | 94 | # Populate is_sent_set_1 from the chunk_order received from path encoding 95 | def _add_chunk_sent(opt, heuristic): 96 | if chunk_order is not None: 97 | assert chunk_time is not None 98 | assert len(chunk_order) == R 99 | assert len(chunk_order[0]) == R 100 | assert len(chunk_order[0][0]) <= L 101 | for r in range(R): 102 | for src in self.topology.sources(r): 103 | for l in range(self.topology.link(src,r)): 104 | for c in chunk_order[r][src][l]: 105 | assert r in self.spsets[c] 106 | self.is_sent_set_1.add((c,src,r,l)) 107 | 108 | def _add_switch_order(switch_chunk_order_recv, switch_chunk_order_send, switch_link_mapping_recv, switch_link_mapping_send): 109 | # Creating new datastructures from order information 110 | # Order recvs coming into and going out from a GPU connected to a switch 111 | # Input: 112 | # switch_chunk_order_recv[r][ll] : [(c1,src1), ...] in order 113 | # switch_chunk_order_send[r][ll] : [(c1,dst1), ...] 
in order 114 | # switch_link_mapping[r][ll] = l 115 | # Output: 116 | # recv_right_after[r][ll][(c,srci)] = (o,srcj) => GPU r over switch l receives o from srcj right after receiving c from srci 117 | # send_right_after[r][ll][(c,dsti)] = (o,dstj) => GPU r over switch l sends o to dstj right after sending c to dsti 118 | # recv_ and send_ right_after give the first chunk received from / sent to a different GPU 119 | # (c,o,r,l,srci,srcj) \in recv_first_set_1 => c is recvd on r from srci anytime before o is recvd on r from srcj 120 | # (c,o,r,l,dsti,dstj) \in send_first_set_1 => c is sent from r to dsti anytime before o is sent from r to dstj 121 | # Note that the l and ll are different for right_after and first_set 122 | LL = 0 123 | for r in range(R): 124 | LL = max(LL, len(switch_chunk_order_recv[r])) 125 | LL = max(LL, len(switch_chunk_order_send[r])) 126 | 127 | recv_right_after, recv_first_set_1, send_right_after, send_first_set_1 = add_switch_order(switch_chunk_order_recv, 128 | switch_chunk_order_send, switch_link_mapping_recv, switch_link_mapping_send, 129 | R, LL) 130 | for recv in recv_first_set_1: 131 | self.recv_first_set_1.add(recv) 132 | for send in send_first_set_1: 133 | self.send_first_set_1.add(send) 134 | return recv_right_after, send_right_after 135 | 136 | def _add_chunk_order(opt, heuristic, recv_right_after, send_right_after): 137 | if chunk_order is not None: 138 | assert chunk_time is not None 139 | assert len(chunk_order) == R 140 | assert len(chunk_order[0]) == R 141 | for r in range(R): 142 | for src in self.topology.sources(r): 143 | for l in range(self.topology.link(src,r)): 144 | this_chunk_order = chunk_order[r][src][l] 145 | max_contig = 6 146 | for i, c in enumerate(this_chunk_order): 147 | j = i + 1 148 | is_input_i = self.collective.precondition(src,c) 149 | while j 0 167 | if (recv_right_after[r][ll][(c,src)] != recv_right_after[r][ll][(o,src)] or send_right_after[src][ll_src][(c,r)] != send_right_after[src][ll_src][(o,r)]): 168 
| self.is_together_set_0.add((c1,o1,r)) 169 | self.is_before_set_1.add((c,o,r)) 170 | skip_others = True 171 | if not skip_others: 172 | if heuristic == 11 and has_i_s_break: 173 | self.is_together_set_0.add((c1,o1,r)) 174 | self.is_before_set_1.add((c,o,r)) 175 | elif _should_fix_together(src,r,l,c,o): 176 | self.is_together_set_1.add((c1,o1,r)) 177 | # Max contiguity allowed = 6 178 | elif _should_try_together(src,r,c,o) and j-ii: 195 | opt.addLConstr(self.is_together[(c1,o1,r)] <= self.is_together[(c2,prev_o2,r)]) 196 | else: 197 | self.is_together_set_0.add((c1,o1,r)) 198 | self.is_before_set_1.add((c,o,r)) 199 | j = j + 1 200 | i = i + 1 201 | 202 | def alpha(r,dst): 203 | assert r != dst 204 | if self._is_relay_link(r,dst): 205 | alpha = self.topology.remote_alpha 206 | assert alpha is not None 207 | return alpha 208 | return 0 209 | 210 | def beta(r,dst): 211 | assert r != dst 212 | if self._is_relay_link(r,dst): 213 | beta = self.topology.remote_beta 214 | assert beta is not None 215 | return beta 216 | return self.topology.get_invbw(r,dst) 217 | 218 | def calc_latency(src,r,l,c): 219 | if self._is_relay_link(src,r): 220 | num_s = 0 221 | for o in range(C): 222 | if (o,src,r,l) in self.is_sent_set_1: 223 | o1,c1 = minmax(o,c) 224 | if (o1,c1,r) in self.is_together_set_1: 225 | assert (o1,c1,r) not in self.is_together 226 | num_s = num_s + 1 227 | continue 228 | if (o1,c1,r) in self.is_together_set_0: 229 | assert (o1,c1,r) not in self.is_together 230 | else: 231 | if (o1,c1,r) not in self.is_together: 232 | self.is_together[(o1,c1,r)] = opt.addVar(vtype=GRB.BINARY) 233 | lat = alpha(src,r) + beta(src,r)*(num_s + quicksum(self.is_together[(o,c,r)] if ((o,src,r,l) in self.is_sent_set_1 and (o,c,r) in self.is_together) else 0 for o in range(c)) + quicksum(self.is_together[(c,o,r)] if ((o,src,r,l) in self.is_sent_set_1 and (c,o,r) in self.is_together) else 0 for o in range(c,C))) 234 | return lat 235 | return alpha(src,r) + beta(src,r) 236 | 237 | # Set chunk 
is_send_set 238 | _add_chunk_sent(opt, heuristic) 239 | 240 | # Populate values 241 | for c in self.collective.chunks(): 242 | for r in self.collective.ranks(): 243 | recvd_anytime = [sum([1 if (c,src,r,l) in self.is_sent_set_1 else 0 for src in self.topology.sources(r)]) for l in range(L)] 244 | recv_IB = [sum([1 if ((c,src,r,l) in self.is_sent_set_1 and self._is_relay_link(src,r)) else 0 for src in self.topology.sources(r)]) for l in range(L)] 245 | if sum(recvd_anytime) == 0: 246 | for l in range(L): 247 | assert (c,c,r) not in self.is_together_set_1 248 | assert (c,c,r) not in self.is_together 249 | self.is_together_set_0.add((c,c,r)) 250 | else: 251 | # Will receive a chunk at most once 252 | assert sum(recvd_anytime) == 1 253 | # for l in range(L): 254 | assert (c,c,r) not in self.is_together_set_1 255 | assert (c,c,r) not in self.is_together 256 | # if recvd_anytime[l] == 1: 257 | self.is_together_set_1.add((c,c,r)) 258 | 259 | # Set ordering 260 | recv_right_after = {} 261 | send_right_after = {} 262 | recv_right_after, send_right_after = _add_switch_order( 263 | switch_chunk_order_recv, 264 | switch_chunk_order_send, 265 | switch_link_mapping_recv, 266 | switch_link_mapping_send) 267 | _add_chunk_order(opt, heuristic, recv_right_after, send_right_after) 268 | 269 | def _get_isbefore(c,o,r): 270 | if (c,o,r) in self.is_before_set_1: 271 | return True, 1 272 | elif (c,o,r) in self.is_before: 273 | return False, self.is_before[(c,o,r)] 274 | else: 275 | return True, 0 276 | 277 | def _get_istogether(c,o,r): 278 | c1,o1 = minmax(c,o) 279 | if (c1,o1,r) in self.is_together_set_1: 280 | return True, 1 281 | elif (c1,o1,r) in self.is_together: 282 | return False, self.is_together[(c1,o1,r)] 283 | else: 284 | return True, 0 285 | 286 | # Correctness constraints 287 | for r in self.collective.ranks(): 288 | src_r = [src for src in self.topology.sources(r)] 289 | links_r = {src: self.topology.link(src,r) for src in src_r} 290 | for c in self.collective.chunks():
291 | opt.addLConstr(self.start[c,r] <= ST) 292 | if r not in self.spsets[c]: 293 | opt.addLConstr(self.start[c,r] == ST) 294 | for src in src_r: 295 | for l in range(L): 296 | opt.addLConstr(self.send[c,src,r,l] == SND) 297 | continue 298 | if self.collective.precondition(r, c): 299 | opt.addLConstr(self.start[c,r] == 0) 300 | else: 301 | # Bandwidth constraint 302 | for src in src_r: 303 | for l in range(links_r[src]): 304 | if (c,src,r,l) in self.is_sent_set_1: 305 | opt.addLConstr(self.start[c,r] == self.send[c,src,r,l] + calc_latency(src,r,l,c)) 306 | else: 307 | opt.addLConstr(self.send[c,src,r,l] >= SND) 308 | for l in range(links_r[src], L): 309 | opt.addLConstr(self.send[c,src,r,l] == SND) 310 | recvd_anytime = sum([sum([1 if (c,src,r,l) in self.is_sent_set_1 else 0 for l in range(links_r[src])]) for src in src_r]) 311 | if self.collective.postcondition(r, c): 312 | opt.addLConstr(self.start[c,r] <= self.time) 313 | assert recvd_anytime == 1, f'{c} {r} {self.is_sent_set_1}' 314 | else: 315 | assert recvd_anytime <= 1 316 | if recvd_anytime == 0: 317 | opt.addLConstr(self.start[c,r] >= self.time + 1) 318 | else: 319 | opt.addLConstr(self.start[c,r] <= self.time) 320 | 321 | for src in src_r: 322 | for l in range(links_r[src]): 323 | if (c,src,r,l) in self.is_sent_set_1: 324 | opt.addLConstr(self.start[c,src] <= self.start[c,r]) 325 | opt.addLConstr(self.start[c,src] <= self.send[c,src,r,l]) 326 | 327 | 328 | # Order sends from same gpu to same gpu 329 | for o in range(c): 330 | is_static_cor, is_before_cor = _get_isbefore(c,o,r) 331 | is_static_ocr, is_before_ocr = _get_isbefore(o,c,r) 332 | is_static_t_ocr, is_together_ocr = _get_istogether(o,c,r) 333 | # chunks sent together must have same send and start time 334 | if is_static_t_ocr and is_together_ocr == 1: 335 | for src in src_r: 336 | for l in range(self.topology.link(src,r)): 337 | if (c,src,r,l) in self.is_sent_set_1 and (o,src,r,l) in self.is_sent_set_1: 338 | opt.addLConstr(self.send[c,src,r,l] == 
self.send[o,src,r,l]) 339 | opt.addLConstr(self.start[c,r] == self.start[o,r]) 340 | elif not is_static_t_ocr: 341 | for src in src_r: 342 | for l in range(self.topology.link(src,r)): 343 | if (c,src,r,l) in self.is_sent_set_1 and (o,src,r,l) in self.is_sent_set_1: 344 | opt.addGenConstrIndicator(self.is_together[(o,c,r)], True, self.send[c,src,r,l] == self.send[o,src,r,l]) 345 | opt.addGenConstrIndicator(self.is_together[(o,c,r)], True, self.start[c,r] == self.start[o,r]) 346 | 347 | 348 | if is_static_cor and is_static_ocr and is_static_t_ocr: 349 | sent_same = any([1 if (c,src,r,l) in self.is_sent_set_1 and (o,src,r,l) in self.is_sent_set_1 else 0 for l in range(L) for src in self.topology.sources(r)]) 350 | sent_val = 1 if sent_same else 0 351 | assert is_before_cor + is_before_ocr + is_together_ocr == sent_val, f'assertion error: {is_before_cor}, {is_before_ocr}, {is_together_ocr}, {sent_val}, {sent_same}, {c}, {o}, {r}, {l}' 352 | 353 | # Bandwidth constraints based on chunk send times 354 | for src in src_r: 355 | for l in range(self.topology.link(src,r)): 356 | if (c,src,r,l) in self.is_sent_set_1 and (o,src,r,l) in self.is_sent_set_1: 357 | lat_o = calc_latency(src,r,l,o) 358 | lat_c = calc_latency(src,r,l,c) 359 | 360 | if (c,o,r) in self.is_before_set_1: 361 | opt.addLConstr(self.send[c,src,r,l] + lat_c <= self.send[o,src,r,l]) 362 | elif (c,o,r) in self.is_before: 363 | opt.addLConstr(self.send[c,src,r,l] + lat_c <= self.send[o,src,r,l] + M*(1-self.is_before[(c,o,r)])) 364 | if (o,c,r) in self.is_before_set_1: 365 | opt.addLConstr(self.send[o,src,r,l] + lat_o <= self.send[c,src,r,l]) 366 | elif (o,c,r) in self.is_before: 367 | opt.addLConstr(self.send[o,src,r,l] + lat_o <= self.send[c,src,r,l] + M*(1-self.is_before[(o,c,r)])) 368 | 369 | num_local_nodes = R // self.topology.copies 370 | # Order receives from a switch 371 | for (c,src,r,l) in self.is_sent_set_1: 372 | if (src,r) in self.topology.switches_involved_in: 373 | for (swt_i, swt_type) in 
self.topology.switches_involved_in[(src,r)]: 374 | srcs_check = [] 375 | if l == swt_i: 376 | for srcs, dsts, _, _, switch_name in self.topology.switches[swt_i]: 377 | if r in dsts and "in" in switch_name and src in srcs: 378 | srcs_check = srcs 379 | assert len(srcs_check)>0, f'{r} {c} {src} {l} {self.topology.switches[l]}' 380 | break 381 | lat_c = calc_latency(src,r,l,c) 382 | for o in range(c+1): 383 | for src_o in srcs_check: 384 | if src_o == src: 385 | continue 386 | if (o,src_o,r,l) in self.is_sent_set_1: 387 | if o == c: 388 | assert False 389 | lat_o = calc_latency(src_o,r,l,o) 390 | if (o,c,r,l,src_o,src) in self.recv_first_set_1: 391 | # opt.addLConstr(self.start[o,r] + lat_c <= self.start[c,r]) 392 | opt.addLConstr(self.send[o,src_o,r,l] + lat_o <= self.send[c,src,r,l]) 393 | elif (c,o,r,l,src,src_o) in self.recv_first_set_1: 394 | # opt.addLConstr(self.start[c,r] + lat_o <= self.start[o,r]) 395 | opt.addLConstr(self.send[c,src,r,l] + lat_c <= self.send[o,src_o,r,l]) 396 | else: 397 | assert False, f"no-ordering {o}, {c}, {r}, {src}, {src_o}" 398 | assert (o,c,r,l) not in self.recv_first, f'{o},{c},{r},{l}' 399 | self.recv_first[(o,c,r,l)] = opt.addVar(vtype=GRB.BINARY) 400 | opt.addLConstr(self.start[o,r] + lat_c <= self.start[c,r] + M*(1-self.recv_first[(o,c,r,l)])) 401 | opt.addLConstr(self.start[c,r] + lat_o <= self.start[o,r] + M*(self.recv_first[(o,c,r,l)])) 402 | 403 | # Order sends to a switch 404 | for (c,r,dst,l) in self.is_sent_set_1: 405 | if (r,dst) in self.topology.switches_involved: 406 | for (swt_i, swt_type) in self.topology.switches_involved[(r,dst)]: 407 | dsts_check = [] 408 | if l == swt_i: 409 | for srcs, dsts, _, _, switch_name in self.topology.switches[swt_i]: 410 | if r in srcs and "out" in switch_name and dst in dsts: 411 | dsts_check = dsts 412 | assert len(dsts_check)>0, f'{r} {c} {dst} {l} {self.topology.switches[l]}' 413 | break 414 | lat_c = calc_latency(r,dst,l,c) 415 | for o in range(c+1): 416 | for dst_o in dsts_check: 
417 | if dst_o == dst: 418 | continue 419 | if (o,r,dst_o,l) in self.is_sent_set_1: 420 | lat_o = calc_latency(r,dst_o,l,o) 421 | if (o,c,r,l,dst_o,dst) in self.send_first_set_1: 422 | opt.addLConstr(self.send[o,r,dst_o,l] + lat_o <= self.send[c,r,dst,l]) 423 | elif (c,o,r,l,dst,dst_o) in self.send_first_set_1: 424 | opt.addLConstr(self.send[c,r,dst,l] + lat_c <= self.send[o,r,dst_o,l]) 425 | else: 426 | assert False 427 | assert (o,c,r,l) not in self.send_first, f'{o},{c},{r},{l}' 428 | self.send_first[(o,c,r,l)] = opt.addVar(vtype=GRB.BINARY) 429 | opt.addLConstr(self.send[o,r,dst_o,l] + lat_o <= self.send[c,r,dst,l] + M*(1-self.send_first[(o,c,r,l)])) 430 | opt.addLConstr(self.send[c,r,dst,l] + lat_c <= self.send[o,r,dst_o,l] + M*(self.send_first[(o,c,r,l)])) 431 | 432 | 433 | def optimize(self, chunk_order=None, chunk_time=None, 434 | switch_chunk_order_recv=None, switch_chunk_time_recv=None, 435 | switch_chunk_order_send=None, switch_chunk_time_send=None, 436 | nic_chunk_order_recv=None, nic_chunk_time_recv=None, 437 | nic_chunk_order_send=None, nic_chunk_time_send=None, 438 | switch_link_mapping_recv=None, switch_link_mapping_send=None): 439 | import math 440 | from time import time 441 | print(self.topology.name) 442 | chunkup = self.route_sketch.hyperparameters.chunkup 443 | print("chunkup =", chunkup) 444 | instance_name = 'taccl_{}_{}'.format(self.topology.name, self.collective.name) 445 | start_time = time() 446 | opt = Model(instance_name) 447 | self._encode(opt, chunk_order, chunk_time, 448 | switch_chunk_order_recv, switch_chunk_time_recv, 449 | switch_chunk_order_send, switch_chunk_time_send, 450 | nic_chunk_order_recv, nic_chunk_time_recv, 451 | nic_chunk_order_send, nic_chunk_time_send, 452 | switch_link_mapping_recv, switch_link_mapping_send) 453 | # opt.write(f'model_{instance_name}.lp') 454 | opt.optimize() 455 | end_time = time() 456 | print("strict time (encode+solve)", end_time-start_time, flush=True) 457 | 458 | if opt.status == 
GRB.INFEASIBLE: 459 | opt.computeIIS() 460 | opt.write("model.ilp") 461 | raise ValueError("Infeasible model") 462 | 463 | C = self.collective.num_chunks 464 | R = self.collective.num_nodes 465 | L = self.topology.L 466 | 467 | send_dict = defaultdict(list) 468 | SCALE_TIME = 10 469 | 470 | model_str = "" 471 | other_model_str = "" 472 | for c in range(C): 473 | for r in range(R): 474 | if self.start[c,r].X <= self.time.X + 0.005: 475 | model_str += f'start[{c},{r}]={self.start[c,r].X}\n' 476 | recv_times = defaultdict(list) 477 | chunk_path = [defaultdict(list) for c in range(C)] 478 | for src in range(R): 479 | for r in self.topology.destinations(src): 480 | for l in range(L): 481 | for c_np in chunk_order[r][src][l]: 482 | c = int(c_np) 483 | assert (c,src,r,l) in self.is_sent_set_1 484 | t = int(SCALE_TIME*self.send[c,src,r,l].X + 0.0001) 485 | transfer_str = f'{c}: {src} --{l}--> {r} t={self.send[c,src,r,l].X}\n' 486 | recv_times[t].append(transfer_str) 487 | send_dict[t].append([c,src,r,t,l]) 488 | chunk_path[c][t].append(transfer_str) 489 | for c_np in range(C): 490 | c = int(c_np) 491 | if c not in chunk_order[r][src][l]: 492 | assert (c,src,r,l) not in self.is_sent_set_1 493 | for tval in sorted(recv_times.keys()): 494 | for strval in recv_times[tval]: 495 | model_str += strval 496 | for c in range(C): 497 | for tval in sorted(chunk_path[c].keys()): 498 | for strval in chunk_path[c][tval]: 499 | other_model_str += strval 500 | for c in range(C): 501 | for o in range(c): 502 | for r in range(R): 503 | if (o,c,r) in self.is_together: 504 | if self.is_together[(o,c,r)].X >= 0.995: 505 | model_str += f'({c},{o},{r})\n' 506 | elif (o,c,r) in self.is_together_set_1: 507 | model_str += f'({c},{o},{r}) set\n' 508 | 509 | steps=[] 510 | send_times = sorted(send_dict.keys()) 511 | i = 0 512 | while(i < len(send_times)): 513 | num_sends = [[[0 for _ in range(L)] for _ in range(R)] for _ in range(R)] 514 | j = i + 1 515 | while j < len(send_times): 516 | to_break = 
False 517 | t_end = send_times[j] 518 | for (c,src,r,_,l) in send_dict[t_end]: 519 | for t in range(i,j): 520 | for (ci,_,ri,_,li) in send_dict[send_times[t]]: 521 | if c == ci and src == ri: 522 | to_break = True 523 | break 524 | if to_break: 525 | break 526 | if to_break: 527 | break 528 | if to_break: 529 | break 530 | j = j + 1 531 | sends = [] 532 | for k in range(i,j): 533 | sends.extend(send_dict[send_times[k]]) 534 | num_sends = [[[0 for _ in range(L)] for _ in range(R)] for _ in range(R)] 535 | for (c,src,r,_,l) in sends: 536 | num_sends[r][src][l] = num_sends[r][src][l] + 1 537 | rounds = 0 538 | for srcs, dsts, bw, l, name in self.topology.real_bandwidth_constraints(): 539 | util = 0 540 | for dst in dsts: 541 | for src in srcs: 542 | util += num_sends[dst][src][l] 543 | if rounds <= util * bw * SCALE_TIME: 544 | rounds = math.ceil(util * bw * SCALE_TIME) 545 | steps.append(Step(rounds, sorted(sends, key=lambda x: x[3]))) 546 | i = j 547 | 548 | instance = Instance( 549 | steps=len(steps), 550 | extra_rounds=0, 551 | chunks=chunkup, 552 | ) 553 | soltype = "a" if chunk_order is None else "improve" 554 | from time import time 555 | timestamp = int(time()) 556 | np.save(f'send_dict_{timestamp}.npy', send_dict) 557 | return Algorithm.make_implementation(self.collective, self.topology, instance, steps, cont=True, suffix=f'-tacclsol-{soltype}-{timestamp}') 558 | -------------------------------------------------------------------------------- /taccl/serialization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
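
The decoder and encoder below round-trip TACCL objects through JSON by tagging each serialized dict with an `sccl_type` field and dispatching on it when decoding. A minimal standalone sketch of the same tag-dispatch pattern (the `Point` class here is purely illustrative and not part of TACCL):

```python
import json

# Illustrative stand-in for a serializable object; not a TACCL class.
class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

class TagEncoder(json.JSONEncoder):
    def default(self, o):
        # Emit a type tag alongside the fields, mirroring how SCCLEncoder
        # writes an 'sccl_type' entry for each object it knows about.
        if isinstance(o, Point):
            return {'sccl_type': 'point', 'x': o.x, 'y': o.y}
        return super().default(o)

def _object_hook(o):
    # Rebuild tagged objects; untagged dicts pass through unchanged,
    # as in _sccl_object_hook below.
    if o.get('sccl_type') == 'point':
        return Point(o['x'], o['y'])
    return o

text = TagEncoder().encode(Point(1, 2))
restored = json.JSONDecoder(object_hook=_object_hook).decode(text)
```

The same pair of hooks is what `save_sccl_object`/`load_sccl_object` plug into `json` at the end of this file.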
3 | 4 | from taccl.algorithm import Algorithm, Step 5 | from taccl.topologies import TACCLTopology 6 | from taccl.instance import Instance 7 | from taccl.collectives import Collective, Chunk, Rank 8 | 9 | import json 10 | import warnings 11 | 12 | def _sccl_object_hook(o): 13 | if not 'sccl_type' in o: 14 | return o 15 | if o['sccl_type'] == 'algorithm': 16 | input_map = { int(k): set(v) for k, v in o['input_map'].items() } 17 | output_map = { int(k): set(v) for k, v in o['output_map'].items() } 18 | return Algorithm(o['name'], o['collective'], o['topology'], o['instance'], o['steps'], input_map, output_map) 19 | if o['sccl_type'] == 'step': 20 | if len(o['sends'][0]) == 6: 21 | sends = [(addr, src, dst,t,l, redop) for addr, src, dst, t,l,redop in o['sends']] 22 | elif len(o['sends'][0]) == 5: 23 | sends = [(addr, src, dst,t,l) for addr, src, dst, t,l in o['sends']] 24 | elif len(o['sends'][0]) == 4: 25 | sends = [(addr, src, dst,t) for addr, src, dst, t in o['sends']] 26 | else: 27 | sends = [(addr, src, dst) for addr, src, dst in o['sends']] 28 | return Step(o['rounds'], sends) 29 | if o['sccl_type'] == 'collective': 30 | triggers = { (int(r), int(c)): v for r, rmap in o['triggers'].items() for c, v in rmap.items() } 31 | return Collective(o['name'], o['nodes'], o['chunks'], o['ranks'], triggers) 32 | if o['sccl_type'] == 'chunk': 33 | pre = set(o['pre']) 34 | post = set(o['post']) 35 | return Chunk(pre, post, o['addr']) 36 | if o['sccl_type'] == 'rank': 37 | pre = set(o['pre']) 38 | post = set(o['post']) 39 | return Rank(pre, post, o['id']) 40 | if o['sccl_type'] == 'topology': 41 | return TACCLTopology(o['name'], o['copies'], o['ngpus_per_node'], o['node_links'], o['node_invbws'], o['remote_invbw'], o['remote_alpha'], o['remote_beta'], o['internode_conn'], o['local_switches']) 42 | if o['sccl_type'] == 'instance': 43 | return Instance(o['steps'], o['extra_rounds'], o['chunks'], o['pipeline'], o['extra_memory'], o['allow_exchange']) 44 | warnings.warn('Unhandled 
sccl_type in JSON') 45 | 46 | def SCCLDecoder(): 47 | return json.JSONDecoder(object_hook=_sccl_object_hook) 48 | 49 | class SCCLEncoder(json.JSONEncoder): 50 | def __init__(self): 51 | super().__init__() 52 | 53 | def default(self, o): 54 | if isinstance(o, Algorithm): 55 | input_map = { k: list(v) for k, v in o.input_map.items() } 56 | output_map = { k: list(v) for k, v in o.output_map.items() } 57 | return { 58 | 'sccl_type': 'algorithm', 59 | 'name': o.name, 60 | 'instance': o.instance, 61 | 'input_map': input_map, 62 | 'output_map': output_map, 63 | 'steps': o.steps, 64 | 'collective': o.collective, 65 | 'topology': o.topology, 66 | } 67 | if isinstance(o, Step): 68 | return { 69 | 'sccl_type': 'step', 70 | 'rounds': o.rounds, 71 | 'sends': o.sends, 72 | } 73 | if isinstance(o, Collective): 74 | triggers = {} 75 | for (r, c), v in o._triggers.items(): 76 | if not r in triggers: 77 | triggers[r] = {} 78 | triggers[r][c] = v 79 | return { 80 | 'sccl_type': 'collective', 81 | 'name': o.name, 82 | 'nodes': o.num_nodes, 83 | 'chunks': o._chunks, 84 | 'ranks': o._ranks, 85 | 'triggers': triggers, 86 | } 87 | if isinstance(o, Chunk): 88 | return { 89 | 'sccl_type': 'chunk', 90 | 'pre': list(o.precondition), 91 | 'post': list(o.postcondition), 92 | 'addr': o.address, 93 | } 94 | if isinstance(o, Rank): 95 | return { 96 | 'sccl_type': 'rank', 97 | 'pre': list(o.precondition), 98 | 'post': list(o.postcondition), 99 | 'id': o.id, 100 | } 101 | if isinstance(o, TACCLTopology): 102 | return { 103 | 'sccl_type': 'topology', 104 | 'name': o.name, 105 | 'copies': o.copies, 106 | 'ngpus_per_node' : o.ngpus_per_node, 107 | 'node_links' : o.node_links, 108 | 'node_invbws' : o.node_invbws, 109 | 'remote_invbw' : o.remote_invbw, 110 | 'remote_alpha' : o.remote_alpha, 111 | 'remote_beta' : o.remote_beta, 112 | 'internode_conn' : o.internode_conn, 113 | 'local_switches' : o.local_switches, 114 | } 115 | if isinstance(o, Instance): 116 | return { 117 | 'sccl_type': 'instance', 118 | 
'steps': o.steps, 119 | 'extra_rounds': o.extra_rounds, 120 | 'chunks': o.chunks, 121 | 'pipeline': o.pipeline, 122 | 'extra_memory': o.extra_memory, 123 | 'allow_exchange': o.allow_exchange, 124 | } 125 | return json.JSONEncoder.default(self, o) 126 | 127 | def save_sccl_object(obj, filename): 128 | with open(filename, 'w') as f: 129 | f.write(SCCLEncoder().encode(obj)) 130 | 131 | def load_sccl_object(filename): 132 | with open(filename) as f: 133 | return SCCLDecoder().decode(f.read()) 134 | -------------------------------------------------------------------------------- /taccl/shortest_path_sets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | from collections import defaultdict 4 | 5 | import math 6 | 7 | def _distances(topology): 8 | # Floyd–Warshall algorithm for all-pairs shortest paths with path information 9 | # Modified to track all shortest paths 10 | nodes = range(topology.num_nodes()) 11 | dist = [[math.inf for _ in nodes] for _ in nodes] 12 | next = [[set() for _ in nodes] for _ in nodes] 13 | for dst in nodes: 14 | for src in topology.sources(dst): 15 | dist[src][dst] = 1 16 | next[src][dst].add(dst) 17 | for node in nodes: 18 | dist[node][node] = 0 19 | next[node][node].add(node) 20 | for k in nodes: 21 | for i in nodes: 22 | for j in nodes: 23 | if dist[i][j] > dist[i][k] + dist[k][j]: 24 | dist[i][j] = dist[i][k] + dist[k][j] 25 | next[i][j] = set() 26 | for l in next[i][k]: 27 | next[i][j].add(l) 28 | elif dist[i][j] == dist[i][k] + dist[k][j]: 29 | for l in next[i][k]: 30 | next[i][j].add(l) 31 | 32 | return dist, next 33 | 34 | def shortest_path_sets(topology, collective): 35 | dist, next = _distances(topology) 36 | nodes = range(topology.num_nodes()) 37 | spsets = {} 38 | for id, chunk in enumerate(collective._chunks): 39 | spset = set() 40 | for u in chunk.precondition: 41 | for v in chunk.postcondition: 42 | curr = 
next[u][v] 43 | if not curr: 44 | continue 45 | spset.add(u) 46 | while not v in curr: 47 | spset.update(curr) 48 | curr = set().union(*[next[x][v] for x in curr]) 49 | spset.update(curr) 50 | spsets[id] = spset 51 | 52 | return spsets 53 | -------------------------------------------------------------------------------- /taccl/topologies/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .generic import * 5 | from .route_sketch import * 6 | from .topology import * -------------------------------------------------------------------------------- /taccl/topologies/generic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import json 5 | from .topology import NodeTopology 6 | 7 | def validate_and_modify_topo(topo_json, check_links=True): 8 | assert "name" in topo_json, "Provide a name in the topo file" 9 | assert "gpus_per_node" in topo_json 10 | assert "alpha" in topo_json 11 | devices = topo_json["gpus_per_node"] 12 | assert devices > 0 13 | if check_links: 14 | assert "links" in topo_json 15 | assert "invbws" in topo_json 16 | assert "betas" in topo_json 17 | assert "node_invbws_list" not in topo_json 18 | assert "node_betas_list" not in topo_json 19 | assert len(topo_json["links"]) == devices 20 | assert len(topo_json["betas"]) == devices 21 | assert len(topo_json["invbws"]) == devices 22 | for l in topo_json["links"]: 23 | assert isinstance(l, list) 24 | assert len(l) == devices 25 | for l in topo_json["invbws"]: 26 | assert isinstance(l, list) 27 | assert len(l) == devices 28 | for l in topo_json["betas"]: 29 | assert isinstance(l, list) 30 | assert len(l) == devices 31 | else: 32 | assert "links" not in topo_json 33 | assert "invbws" not in topo_json 34 | assert "node_invbws_list" in topo_json 35 | assert 
"node_betas_list" in topo_json 36 | if ("nics_per_node" in topo_json): 37 | assert "remote_alpha" in topo_json 38 | assert "remote_beta" in topo_json 39 | assert "remote_invbw" in topo_json 40 | else: 41 | topo_json["nics_per_node"] = -1 42 | topo_json["remote_alpha"] = -1 43 | topo_json["remote_beta"] = -1 44 | topo_json["remote_invbw"] = -1 45 | return topo_json 46 | 47 | def custom(topo_file): 48 | topo_json = json.load(topo_file) 49 | topo_json = validate_and_modify_topo(topo_json, check_links=True) 50 | gpus_per_node = topo_json["gpus_per_node"] 51 | links = topo_json["links"] 52 | invbws = topo_json["invbws"] 53 | nics_per_node = topo_json["nics_per_node"] 54 | remote_invbw = topo_json["remote_invbw"] 55 | remote_alpha = topo_json["remote_alpha"] 56 | remote_beta = topo_json["remote_beta"] 57 | name = topo_json["name"] 58 | return NodeTopology(f'Custom-{name}-(n={gpus_per_node})', links, topo_json["alpha"], topo_json["betas"], invbws, nics_per_node, remote_invbw, remote_alpha, remote_beta) 59 | 60 | 61 | def hub_and_spoke(topo_file): 62 | print("topo_file:", topo_file) 63 | f = open(topo_file, "r") 64 | topo_json = json.load(f) 65 | gpus_per_node = topo_json["gpus_per_node"] 66 | assert len(topo_json["node_invbws_list"]) == 1 67 | node_invbw = topo_json["node_invbws_list"][0] 68 | assert len(topo_json["node_betas_list"]) == 1 69 | node_beta = topo_json["node_betas_list"][0] 70 | alpha = topo_json["alpha"] 71 | links = [[0 if x==y else 1 for y in range(gpus_per_node)] for x in range(gpus_per_node)] 72 | betas = [[0 if x==y else node_beta for y in range(gpus_per_node)] for x in range(gpus_per_node)] 73 | invbws = [[0 if x==y else node_invbw for y in range(gpus_per_node)] for x in range(gpus_per_node)] 74 | nics_per_node = topo_json["nics_per_node"] 75 | remote_invbw = topo_json["remote_invbw"] 76 | remote_alpha = topo_json["remote_alpha"] 77 | remote_beta = topo_json["remote_beta"] 78 | name = topo_json["name"] 79 | return NodeTopology(f'HubAndSpoke-{name}-(n={gpus_per_node})', links,
alpha, betas, invbws, nics_per_node, remote_invbw, remote_alpha, remote_beta) 80 | 81 | 82 | def dgx2(topo_file): 83 | print("topo_file:", topo_file) 84 | f = open(topo_file, "r") 85 | topo_json = json.load(f) 86 | topo_json["nics_per_node"] = 8 87 | topo_json["gpus_per_node"] = 16 88 | print("Fixing nics_per_node and gpus_per_node. This will overwrite any values provided") 89 | topo_json = validate_and_modify_topo(topo_json, check_links=False) 90 | assert len(topo_json["node_invbws_list"]) == 1 91 | assert len(topo_json["node_betas_list"]) == 1 92 | node_invbw = int(topo_json["node_invbws_list"][0]) 93 | node_beta = topo_json["node_betas_list"][0] 94 | alpha = topo_json["alpha"] 95 | gpus_per_node = topo_json["gpus_per_node"] 96 | nics_per_node = topo_json["nics_per_node"] 97 | remote_invbw = topo_json["remote_invbw"] 98 | remote_alpha = topo_json["remote_alpha"] 99 | remote_beta = topo_json["remote_beta"] 100 | name = topo_json["name"] 101 | links = [[0 if x==y else 1 for y in range(gpus_per_node)] for x in range(gpus_per_node)] 102 | betas = [[0 if x==y else node_beta for y in range(gpus_per_node)] for x in range(gpus_per_node)] 103 | invbws = [[0 if x==y else node_invbw for y in range(gpus_per_node)] for x in range(gpus_per_node)] 104 | return NodeTopology(f'DGX2-{name}-(n={gpus_per_node})', links, alpha, betas, invbws, nics_per_node, remote_invbw, remote_alpha, remote_beta) 105 | 106 | 107 | def ndv2(topo_file): 108 | print("topo_file:", topo_file) 109 | f = open(topo_file, "r") 110 | topo_json = json.load(f) 111 | f.close() 112 | topo_json["nics_per_node"] = 1 113 | topo_json["gpus_per_node"] = 8 114 | print("Fixing nics_per_node and gpus_per_node.
This will overwrite any values provided") 115 | topo_json = validate_and_modify_topo(topo_json, check_links=False) 116 | assert len(topo_json["node_invbws_list"]) == 2 117 | 118 | # Link connection matrix 119 | links = [ 120 | #0 1 2 3 4 5 6 7 121 | [0, 1, 1, 1, 1, 0, 0, 0], 122 | [1, 0, 1, 1, 0, 1, 0, 0], 123 | [1, 1, 0, 1, 0, 0, 1, 0], 124 | [1, 1, 1, 0, 0, 0, 0, 1], 125 | [1, 0, 0, 0, 0, 1, 1, 1], 126 | [0, 1, 0, 0, 1, 0, 1, 1], 127 | [0, 0, 1, 0, 1, 1, 0, 1], 128 | [0, 0, 0, 1, 1, 1, 1, 0] 129 | ] 130 | 131 | alpha = topo_json["alpha"] 132 | 133 | # NVLink beta for each link 134 | beta_m1 = topo_json["node_betas_list"][0] 135 | beta_m2 = topo_json["node_betas_list"][1] 136 | betas = [ 137 | [0, beta_m1, beta_m2, beta_m2, beta_m1, 0, 0, 0], 138 | [beta_m1, 0, beta_m2, beta_m1, 0, beta_m2, 0, 0], 139 | [beta_m2, beta_m2, 0, beta_m1, 0, 0, beta_m1, 0], 140 | [beta_m2, beta_m1, beta_m1, 0, 0, 0, 0, beta_m2], 141 | [beta_m1, 0, 0, 0, 0, beta_m1, beta_m2, beta_m2], 142 | [0, beta_m2, 0, 0, beta_m1, 0, beta_m2, beta_m1], 143 | [0, 0, beta_m1, 0, beta_m2, beta_m2, 0, beta_m1], 144 | [0, 0, 0, beta_m2, beta_m2, beta_m1, beta_m1, 0] 145 | ] 146 | 147 | # NVLink bandwidth for each link 148 | invbw1 = int(topo_json["node_invbws_list"][0]) 149 | invbw2 = int(topo_json["node_invbws_list"][1]) 150 | invbws = [ 151 | [0, invbw1, invbw2, invbw2, invbw1, 0, 0, 0], 152 | [invbw1, 0, invbw2, invbw1, 0, invbw2, 0, 0], 153 | [invbw2, invbw2, 0, invbw1, 0, 0, invbw1, 0], 154 | [invbw2, invbw1, invbw1, 0, 0, 0, 0, invbw2], 155 | [invbw1, 0, 0, 0, 0, invbw1, invbw2, invbw2], 156 | [0, invbw2, 0, 0, invbw1, 0, invbw2, invbw1], 157 | [0, 0, invbw1, 0, invbw2, invbw2, 0, invbw1], 158 | [0, 0, 0, invbw2, invbw2, invbw1, invbw1, 0] 159 | ] 160 | # Ex. 
for 1 MB data chunks, the following matrix denotes node invbws 161 | # invbws = [ 162 | # [0, 23, 46, 46, 23, 0, 0, 0], 163 | # [23, 0, 46, 23, 0, 46, 0, 0], 164 | # [46, 46, 0, 23, 0, 0, 23, 0], 165 | # [46, 23, 23, 0, 0, 0, 0, 46], 166 | # [23, 0, 0, 0, 0, 23, 46, 46], 167 | # [0, 46, 0, 0, 23, 0, 46, 23], 168 | # [0, 0, 23, 0, 46, 46, 0, 23], 169 | # [0, 0, 0, 46, 46, 23, 23, 0] 170 | # ] 171 | nics_per_node = topo_json["nics_per_node"] 172 | remote_invbw = topo_json["remote_invbw"] 173 | remote_alpha = topo_json["remote_alpha"] 174 | remote_beta = topo_json["remote_beta"] 175 | name = topo_json["name"] 176 | 177 | return NodeTopology(f'NDv2-{name}', links, alpha, betas, invbws, nics_per_node, remote_invbw, remote_alpha, remote_beta) 178 | -------------------------------------------------------------------------------- /taccl/topologies/route_sketch.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class IntraNode: 5 | strategy: str 6 | 7 | @dataclass 8 | class IntraNode_Switch(IntraNode): 9 | switches: list 10 | switch_hyperedge_strategy: list 11 | 12 | @dataclass 13 | class IntraNode_RelaySwitch(IntraNode): 14 | relayed_switch_conn: dict 15 | 16 | @dataclass 17 | class InterNode: 18 | pass 19 | 20 | @dataclass 21 | class InterNode_Switch(InterNode): 22 | switches: list 23 | switch_hyperedge_strategy: list 24 | 25 | @dataclass 26 | class InterNode_Relay(InterNode): 27 | internode_conn: dict 28 | gpus_to_sender_rev_map: dict 29 | enforce_ordering: bool 30 | 31 | @dataclass 32 | class MultiNode: 33 | strategy: list 34 | nnodes: list 35 | group_size: list 36 | 37 | @dataclass 38 | class Symmetry: 39 | offsets: list 40 | 41 | @dataclass 42 | class HyperParameter: 43 | chunkup: int 44 | heuristic: int 45 | 46 | @dataclass 47 | class RouteSketch: 48 | intranode: IntraNode 49 | internode: InterNode 50 | multinode: MultiNode 51 | symmetry: Symmetry 52 | hyperparameters: 
HyperParameter -------------------------------------------------------------------------------- /taccl/topologies/topology.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | from collections import defaultdict 4 | import math 5 | 6 | class NodeTopology(object): 7 | def __init__(self, name, links, alpha, betas, invbws, nics_per_node, remote_invbw, remote_alpha, remote_beta): 8 | self.name = name 9 | self.links = links 10 | self.alpha = alpha 11 | self.betas = betas 12 | self.invbws= invbws 13 | self.nics_per_node = nics_per_node 14 | self.remote_invbw = remote_invbw 15 | self.remote_alpha = remote_alpha 16 | self.remote_beta = remote_beta 17 | 18 | def _add_ext_edge(links, ngpus_per_node, ngpus, copies, internode_conn, remote_link, link_split=None, multinode_split=1): 19 | new_links = [ 20 | [ 21 | links[dst % ngpus_per_node][src % ngpus_per_node] 22 | if src // ngpus_per_node == dst // ngpus_per_node 23 | else 0 24 | for src in range(ngpus) 25 | ] for dst in range(ngpus) 26 | ] 27 | 28 | for i in range(copies): 29 | for j in range(copies): 30 | if i == j: 31 | continue 32 | for s in internode_conn: 33 | for r in internode_conn[s]: 34 | src = int(s) + i * ngpus_per_node 35 | dst = r + j * ngpus_per_node 36 | if link_split is not None: 37 | rlink = remote_link * link_split[s] * multinode_split 38 | else: 39 | 40 | rlink = remote_link * multinode_split 41 | new_links[dst][src] = rlink 42 | return new_links 43 | 44 | def _add_ext_switches(invbws, ngpus_per_node): 45 | swt = [] 46 | for r in range(len(invbws)): 47 | c = r // ngpus_per_node 48 | node = r % ngpus_per_node 49 | dsts = [dst for dst in range(len(invbws)) if ((r//ngpus_per_node != dst//ngpus_per_node) and (invbws[dst][r] > 0))] 50 | srcs = [src for src in range(len(invbws)) if ((r//ngpus_per_node != src//ngpus_per_node) and (invbws[r][src] > 0))] 51 | invbw = None 52 | if len(dsts) != 0: 53 | 
invbw = invbws[dsts[0]][r] 54 | if len(srcs) != 0: 55 | invbw = invbws[r][srcs[0]] 56 | if invbw is not None: 57 | swt.append(([r], dsts, 1, invbw, f'copy_{c}_node_{node}_out_remote')) 58 | swt.append((srcs, [r], 1, invbw, f'copy_{c}_node_{node}_in_remote')) 59 | return swt 60 | 61 | def _make_switch(switches, node_beta, copies, ngpus_per_node): 62 | new_switches = [] 63 | num_switches = len(switches) 64 | for i in range(num_switches): 65 | swt = [] 66 | for c in range(copies): 67 | for node in switches[i]: 68 | dist_node = node + c * ngpus_per_node 69 | dist_others = [other + c * ngpus_per_node for other in switches[i] if other != node] 70 | invbw = node_beta[dist_others[0]][dist_node] 71 | for o in dist_others: 72 | assert node_beta[o][dist_node] == invbw 73 | swt.append(([dist_node], dist_others, 1, invbw, f'copy_{c}_node_{node}_swt_{i}_out_local')) 74 | swt.append((dist_others, [dist_node], 1, invbw, f'copy_{c}_node_{node}_swt_{i}_in_local')) 75 | new_switches.append(swt) 76 | return new_switches 77 | 78 | # TODO: whether to make relays switch and have same b/w or have them share b/w 79 | # will be decided by multinode-sketch and we need to handle the way links b/w is assigned 80 | # Right now, we assume that there will only be a single group of nodes, allowing the link to 81 | # be = group_size * link 82 | class TACCLTopology(object): 83 | def __init__(self, name, copies, ngpus_per_node, node_links, 84 | node_invbws, remote_invbw, remote_alpha, remote_beta, 85 | internode_conn, switches=[]): 86 | self.name = name 87 | self.copies = copies 88 | self.ngpus_per_node = ngpus_per_node 89 | self.ngpus = copies * ngpus_per_node 90 | self.node_links = node_links 91 | self.node_invbws = node_invbws 92 | self.invbws = node_invbws 93 | self.internode_conn = internode_conn 94 | self.base_gpus = [] 95 | self.local_switches = switches 96 | self.remote_beta = remote_beta 97 | ext_switches = [] 98 | base_gpu = 0 99 | for c in range(copies): 100 | self.base_gpus.append(base_gpu)
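
As an aside, the `base_gpus` list built in the loop here is a running prefix of starting GPU ids, `[0, n, 2n, ..., copies*n]` for `n = ngpus_per_node`, and it backs the `gpu_to_node`/`node_to_gpu` lookups defined later in this class. A minimal standalone sketch under illustrative values (note that indexing `base_gpus[node]` directly yields a node's first GPU id, since the list is already cumulative):

```python
copies, ngpus_per_node = 3, 4

# Prefix of starting GPU ids per node, plus a sentinel equal to the
# total GPU count — the same shape TACCLTopology.__init__ builds.
base_gpus = [c * ngpus_per_node for c in range(copies + 1)]  # [0, 4, 8, 12]

def gpu_to_node(gpu):
    # Find the node whose [base, next_base) interval contains this GPU.
    for c in range(copies):
        if base_gpus[c] <= gpu < base_gpus[c + 1]:
            return c

def node_to_gpu(node):
    # First GPU id on the given node.
    return base_gpus[node]
```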
101 |             base_gpu += ngpus_per_node
102 |         self.base_gpus.append(base_gpu)
103 |         links = node_links
104 |         if copies > 1:
105 |             links = _add_ext_edge(node_links, ngpus_per_node, self.ngpus, copies, internode_conn, 1)
106 |             self.invbws = _add_ext_edge(node_invbws, ngpus_per_node, self.ngpus, copies, internode_conn, remote_invbw)
107 |             if copies > 2:
108 |                 ext_switches = _add_ext_switches(self.invbws, ngpus_per_node)
109 | 
110 |         self.links = links
111 |         self.remote_invbw = remote_invbw
112 |         self.remote_alpha = remote_alpha
113 | 
114 |         self.switches = _make_switch(switches, self.invbws, copies, ngpus_per_node)
115 |         if len(ext_switches) > 0:
116 |             if len(self.switches) == 0:
117 |                 self.switches.append(ext_switches)
118 |             else:
119 |                 self.switches[0].extend(ext_switches)
120 |         self.num_switches = len(switches)
121 | 
122 |         # switches = [switch1, switch2, ...]
123 |         # All unique switch src-dsts go in switch1
124 |         # If there are 6 switches, as with NVSwitch in DGX-2, we separate them into switch1, switch2, ..., switch6
125 |         # switch1 = [([src0], [dsts], 1, invbw, "out-1"), ([srcs], [dst0], 1, invbw, "in-1"), ...]
126 |         # switch2 = [([src0], [dsts], 1, invbw, "out-2"), ([srcs], [dst0], 1, invbw, "in-2"), ...]
127 |         for switch in self.switches:
128 |             # print("switch:", switch)
129 |             for srcs, dsts, lk, invbw, switch_name in switch:
130 |                 if lk == 0:
131 |                     raise ValueError(f'Switch {switch_name} has zero bandwidth, but switch bandwidths must be strictly positive. Please encode connectedness in links.')
132 |                 if invbw < 0:
133 |                     raise ValueError(f'Switch {switch_name} has a negative inverse bandwidth of {invbw}. Bandwidth must be strictly positive.')
134 |         self.bw_dist, _ = self.set_bw_distances()
135 |         self.set_switches_involved()
136 |         self.set_L()
137 | 
138 |     def gpu_to_node(self, gpu):
139 |         assert gpu < self.ngpus
140 |         assert gpu >= 0
141 |         for c in range(self.copies):
142 |             if gpu >= self.base_gpus[c] and gpu < self.base_gpus[c+1]:
143 |                 return c
144 | 
145 |     def node_to_gpu(self, node):
146 |         assert node < self.copies
147 |         assert node >= 0
148 |         # base_gpus already stores the cumulative first-GPU id of each node
149 |         return self.base_gpus[node]
150 | 
151 |     def sources(self, dst):
152 |         for src, bw in enumerate(self.links[dst]):
153 |             if bw > 0:
154 |                 yield src
155 | 
156 |     def destinations(self, src):
157 |         for dst, links in enumerate(self.links):
158 |             bw = links[src]
159 |             if bw > 0:
160 |                 yield dst
161 | 
162 |     def link(self, src, dst):
163 |         return self.links[dst][src]
164 | 
165 |     def get_invbw(self, src, dst):
166 |         return self.invbws[dst][src]
167 | 
168 |     def num_nodes(self):
169 |         return len(self.links)
170 | 
171 |     def nodes(self):
172 |         return range(self.num_nodes())
173 | 
174 |     # constraints using number of links
175 |     def bandwidth_constraints(self):
176 |         for dst, dst_links in enumerate(self.links):
177 |             for src, lk in enumerate(dst_links):
178 |                 if lk > 0:
179 |                     yield ([src], [dst], lk, f'{src}→{dst}')
180 |         for switch in self.switches:
181 |             for srcs, dsts, lk, _, switch_name in switch:
182 |                 yield (srcs, dsts, lk, switch_name)
183 | 
184 |     # constraints using actual bandwidth
185 |     def real_bandwidth_constraints(self):
186 |         for dst, dst_links in enumerate(self.invbws):
187 |             for src, invbw in enumerate(dst_links):
188 |                 if invbw > 0:
189 |                     for l in range(self.link(src, dst)):
190 |                         yield ([src], [dst], invbw, l, f'{src}→{dst}')
191 |         for swt_i, switch in enumerate(self.switches):
192 |             for srcs, dsts, _, invbw, switch_name in switch:
193 |                 yield (srcs, dsts, invbw, swt_i, switch_name)
194 | 
195 |     def set_bw_distances(self):
196 |         if self.remote_beta is None:
197 |             return None, None
198 |         # Floyd–Warshall algorithm for all-pairs shortest paths with path information
199 |         # Modified to track all shortest paths
200 |         nodes = range(self.num_nodes())
201 |         dist = [[math.inf for _ in nodes] for _ in nodes]
202 |         next = [[set() for _ in nodes] for _ in nodes]
203 |         for dst in nodes:
204 |             for src in self.sources(dst):
205 |                 dist[src][dst] = self.invbws[dst][src]
206 |                 next[src][dst].add(dst)
207 |         for node in nodes:
208 |             dist[node][node] = 0
209 |             next[node][node].add(node)
210 |         for k in nodes:
211 |             for i in nodes:
212 |                 for j in nodes:
213 |                     if dist[i][j] >= dist[i][k] + dist[k][j]:
214 |                         dist[i][j] = dist[i][k] + dist[k][j]
215 |                         next[i][j].update(next[i][k])
216 |         return dist, next
217 | 
218 |     def set_switches_involved(self):
219 |         self.switches_involved = defaultdict(list)
220 |         self.switches_involved_in = defaultdict(list)
221 |         self.switch_dst_dict = defaultdict(dict)
222 |         self.switch_src_dict = defaultdict(dict)
223 |         for i, switch in enumerate(self.switches):
224 |             for srcs, dsts, _, _, switch_name in switch:
225 |                 if "out" in switch_name:
226 |                     assert len(srcs) == 1
227 |                     self.switch_dst_dict[srcs[0]][(i, switch_name[:-6])] = dsts
228 |                     for dst in dsts:
229 |                         self.switches_involved[(srcs[0], dst)].append((i, switch_name[:-6]))
230 |                 if "in" in switch_name:
231 |                     assert len(dsts) == 1
232 |                     self.switch_src_dict[dsts[0]][(i, switch_name[:-6])] = srcs
233 |                     for src in srcs:
234 |                         self.switches_involved_in[(src, dsts[0])].append((i, switch_name[:-6]))
235 |         # print("si: ", self.switches_involved)
236 |         # print("sii: ", self.switches_involved_in)
237 | 
238 |     def reverse_links(self):
239 |         num_nodes = self.num_nodes()
240 |         new_links = [[None for y in range(num_nodes)] for x in range(num_nodes)]
241 |         for x in range(num_nodes):
242 |             for y in range(num_nodes):
243 |                 new_links[x][y] = self.links[y][x]
244 |         new_invbws = [[None for y in range(num_nodes)] for x in range(num_nodes)]
245 |         for x in range(num_nodes):
246 |             for y in range(num_nodes):
247 |                 new_invbws[x][y] = self.invbws[y][x]
248 |         new_switches = []
249 |         for swt in self.switches:
250 |             new_swt = []
251 |             for srcs, dsts, lk, invbw, name in swt:
252 |                 if "out" in name:
253 |                     new_name = name.replace('out', 'in')
254 |                 elif "in" in name:
255 |                     new_name = name.replace('in', 'out')
256 |                 new_swt.append((dsts, srcs, lk, invbw, new_name))
257 |             new_switches.append(new_swt)
258 | 
259 |         self.links = new_links
260 |         self.invbws = new_invbws
261 |         self.switches = new_switches
262 |         self.set_switches_involved()
263 | 
264 |     def set_L(self):
265 |         R = self.num_nodes()
266 |         L = 0
267 |         for src in range(R):
268 |             for dst in self.destinations(src):
269 |                 if self.link(src, dst) > L:
270 |                     L = self.link(src, dst)
271 |         self.L = L
272 | 
--------------------------------------------------------------------------------
/taccl/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | from collections import defaultdict
5 | import enum
6 | import pickle as pkl
7 | import sys
8 | 
9 | # Creating new data structures from order information
10 | # Orders recvs coming into and sends going out from a GPU connected to a switch
11 | # Input:
12 | #   switch_chunk_order_recv[r][ll] : [(c1,src1), ...] in order
13 | #   switch_chunk_order_send[r][ll] : [(c1,dst1), ...] in order
14 | # Output:
15 | #   recv_right_after[r][l][(c,srci)] = (o,srcj) => GPU r over switch l receives o from srcj right after receiving c from srci
16 | #   send_right_after[r][l][(c,dsti)] = (o,dstj) => GPU r over switch l sends o to dstj right after sending c to dsti
17 | #   (c,o,r,l,srci,srcj) \in recv_first_set_1 => c is recvd on r from srci anytime before o is recvd on r from srcj
18 | #   (c,o,r,l,dsti,dstj) \in send_first_set_1 => c is sent from r to dsti anytime before o is sent from r to dstj
19 | def add_switch_order(switch_chunk_order_recv, switch_chunk_order_send, switch_link_mapping_recv, switch_link_mapping_send, R, LL):
20 |     recv_right_after = [[defaultdict() for l in range(LL)] for r in range(R)]
21 |     send_right_after = [[defaultdict() for l in range(LL)] for r in range(R)]
22 |     recv_first_set_1 = set()
23 |     send_first_set_1 = set()
24 | 
25 |     if switch_chunk_order_recv is not None:
26 |         assert len(switch_chunk_order_recv) == R
27 |         for r in range(R):
28 |             for ll, chunk_order_recv in enumerate(switch_chunk_order_recv[r]):
29 |                 for i, (c, srci) in enumerate(chunk_order_recv):
30 |                     l = switch_link_mapping_recv[r][ll]
31 |                     j = i + 1
32 |                     has_after = False
33 |                     while j