├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── THIRD-PARTY-LICENSES ├── docs └── images │ └── overview_horizontal.png ├── dynapipe ├── __init__.py ├── data_opt │ ├── Makefile │ ├── __init__.py │ ├── cost_models.py │ ├── dp_helper.cpp │ └── optimizer.py ├── memory_opt │ ├── allocation_simulator.py │ ├── cuda_caching_allocator.cpp │ ├── cuda_caching_allocator.py │ ├── host_caching_allocator.py │ ├── setup.py │ ├── traceback.h │ └── utils.py ├── model.py ├── pipe │ ├── data_loader.py │ ├── executor.py │ ├── instruction_optimizer.py │ ├── instructions.py │ ├── kv_redis.py │ └── utils.py ├── schedule_opt │ ├── __init__.py │ ├── cyclic_schedule.py │ ├── execution_planner.py │ ├── fifo_schedule.py │ ├── ofob_schedule.py │ ├── schedule_common.py │ ├── wait_free_cyclic_schedule.py │ └── wait_free_schedule.py └── utils │ ├── logger.py │ └── memory_utils.py ├── requirements.txt ├── scripts ├── check_dataloader_logs.py ├── check_executor_logs.py ├── estimate_memory_usage.py ├── generate_random_mb_spec.py ├── plot_microbatch_permutation.py ├── pre-commit.hook ├── run_format.sh ├── simulation │ ├── compare_batching_methods.py │ ├── schedule_under_dynamic_mb.py │ └── shift_trace_json.py └── validate_execution_plans.py ├── setup.py └── tests ├── test_dataloader └── test_dataloader.py ├── test_dev_assign_validation.py ├── test_instruction_optimizer.py ├── test_instructions.py ├── test_kv_store.py ├── test_logger.py ├── test_memory_opt ├── test_cuda_memory_stats.py └── test_host_caching_allocator.py └── test_scheduler └── test_scheduler.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. 
You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ## Optimizing Multi-task Training through Dynamic Pipelines 5 | 6 | Official repository for the paper *DynaPipe: Optimizing Multi-task Training through Dynamic Pipelines* ([Paper](https://arxiv.org/abs/2311.10418)). 7 | 8 | During multi-task training, the model commonly receives input sequences of highly different lengths due to the diverse contexts of different tasks. Padding (to the same sequence length) or packing (short examples into long sequences of the same length) is usually adopted to prepare input samples for model training, which is nonetheless not space or computation efficient. This project adopts a dynamic micro-batching approach to tackle sequence length variation. Each input global batch is split into multiple variable-length micro-batches, each of which comprises a (potentially different) number of samples of similar sequence lengths. These micro-batches are efficiently organized into pipelines, facilitating efficient 3D-parallel (data, tensor and pipeline) multi-task model training. 
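
As a toy illustration of the idea only (this greedy sketch is for intuition and is not DynaPipe's actual dynamic-programming planner; the function name and token budget below are made up), samples can be sorted by length and grouped so that each micro-batch is only padded to the longest sequence it actually contains:

```
# Toy illustration of dynamic micro-batching (not DynaPipe's actual
# dynamic-programming planner): group samples of similar lengths so that each
# micro-batch is only padded to the longest sequence it contains.
def split_into_microbatches(seq_lengths, max_padded_tokens=4096):
    microbatches, current = [], []
    for length in sorted(seq_lengths):
        # Padded size of the current micro-batch if this sample were added.
        padded_tokens = (len(current) + 1) * max(current + [length])
        if current and padded_tokens > max_padded_tokens:
            microbatches.append(current)
            current = []
        current.append(length)
    if current:
        microbatches.append(current)
    return microbatches

# A global batch with highly varying sequence lengths:
print(split_into_microbatches([32, 48, 64, 500, 512, 2048, 4000]))
# [[32, 48, 64, 500, 512], [2048], [4000]]
```
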
9 | 10 | Main features of this project include: 11 | 12 | * An efficient dynamic programming algorithm to compute the optimal micro-batching plan for each input global batch. 13 | * A pipeline schedule robust to variable-sized micro-batches, minimizing pipeline bubbles. 14 | * A pipeline executor supporting highly dynamic pipelines (the pipeline schedule, the size and number of micro-batches can vary each iteration), based on an instruction-based abstraction of pipeline operations. 15 | * Overlapped execution plan generation with model training. 16 | 17 | 18 | ## System Diagram 19 | ![System Diagram](docs/images/overview_horizontal.png) 20 | 21 | ## Getting Started 22 | 23 | ### Dependencies 24 | #### Redis 25 | The distributed instruction store uses Redis as the underlying key-value store. [Redis server](https://redis.io/) needs to be installed on machines participating in training. Our code will set up and initialize a Redis server automatically. 26 | 27 | *Note*: The Redis server is not protected by authentication and may pose security risks. Please make sure that the code is only run in a secure environment. 28 | 29 | #### Python Dependencies 30 | Please see [requirements.txt](requirements.txt) for the required Python packages. Install them by running 31 | 32 | ```pip3 install -r requirements.txt``` 33 | 34 | ### Installation 35 | Clone this repository and run 36 | 37 | ```pip3 install -e .``` 38 | 39 | Then, build the C++ extensions by running 40 | 41 | ``` 42 | cd dynapipe/data_opt 43 | make 44 | cd ../memory_opt 45 | python3 setup.py build 46 | ``` 47 | 48 | 49 | ### Pipeline Instructions 50 | To use this project, the Pipeline Instructions (defined [here](/dynapipe/pipe/instructions.py)) need to be implemented in the intended training framework (e.g., Megatron-LM). A reference implementation of the instructions in Megatron-LM can be found [here](https://github.com/chenyu-jiang/Megatron-LM/blob/dynapipe/megatron/pipeline_executor.py). 51 | 52 | ### Using this project 53 | 54 | *Please note that this project is experimental and has only been tested by [integrating with Megatron-LM](https://github.com/chenyu-jiang/Megatron-LM) (please refer to the linked repository for detailed usage).* 55 | 56 | This project interacts with the training framework mainly through the following two interfaces: 57 | 58 | #### Data Loader 59 | 60 | We wrap the micro-batch splitting and execution plan generation process into a `DynaPipeDataLoader`. It takes the normal PyTorch data loader arguments, plus a few additional ones. Please see [here](/dynapipe/pipe/data_loader.py) for the full list of arguments. The returned iterator generates tuples of micro-batched data and the corresponding execution plan for each iteration. This iterator is to be used by the pipeline executor. See [here](https://github.com/chenyu-jiang/Megatron-LM/blob/dynapipe/megatron/data/data_samplers.py) for an example of using the `DynaPipeDataLoader` in Megatron-LM. 61 | 62 | #### Pipeline Executor 63 | 64 | The pipeline executor simply reads in execution plans and calls the Pipeline Instruction Implementations. These implementations are registered with the executor through the [`register_handler` function](https://github.com/chenyu-jiang/Megatron-LM/blob/3def65d56515b0b3e617b47abb86088d79b15c9c/megatron/pipeline_executor.py#L393C17-L393C17). To run the pipeline executor, simply call the `execute` function with the corresponding execution plan in each iteration.
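
The snippet below is a minimal sketch of how the two interfaces fit together in a training loop. It is illustrative only: `DynaPipeDataLoader`, `register_handler`, and `execute` are the interfaces described above, while the loader arguments, the handler registration details, and the way the `executor` object is obtained are assumptions; refer to the linked Megatron-LM integration for the actual signatures.

```
# Illustrative sketch only -- loader arguments, handler registration and the
# construction of `executor` are assumptions; see the Megatron-LM integration
# linked in this README for the real API.
from dynapipe.pipe.data_loader import DynaPipeDataLoader

def train_one_epoch(executor, dataset, dynapipe_kwargs):
    # DynaPipeDataLoader takes the usual PyTorch DataLoader arguments plus the
    # DynaPipe-specific ones documented in dynapipe/pipe/data_loader.py.
    loader = DynaPipeDataLoader(dataset, **dynapipe_kwargs)
    for microbatches, execution_plan in loader:
        # Each iteration yields the micro-batched data together with the
        # execution plan generated for it. The executor replays the plan by
        # calling the instruction handlers previously registered with
        # `register_handler`; how `microbatches` is fed to those handlers is
        # framework-specific.
        executor.execute(execution_plan)
```
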
See [here](https://github.com/chenyu-jiang/Megatron-LM/blob/3def65d56515b0b3e617b47abb86088d79b15c9c/megatron/training.py#L705C9-L705C9) for an example of using the pipeline executor in Megatron-LM. 65 | 66 | #### Environment Variables 67 | 68 | In addition to the above two interfaces, this project can also be configured through the following environment variables: 69 | 70 | * `DYNAPIPE_KV_HOST`: The host IP of the Redis kv store server. Defaults to 'localhost' (required to be set for multi-node training). 71 | * `DYNAPIPE_KV_PORT`: The port for the Redis kv store server. Defaults to 29500. 72 | * `DYNAPIPE_DEBUG`: Logging level. Defaults to 'INFO'. Set to 'DEBUG' for more detailed logging. 73 | * `DYNAPIPE_LOGGING_DEBUG_DIR`: The directory to store all generated logs. 74 | * `DYNAPIPE_DEBUG_DUMP_EP_STATS`: If set to true, dumps the generated execution plans, seen sequence lengths, shapes of the generated micro-batches, estimated memory, and simulated traces for each iteration during training. Used for debugging and for collecting statistics during our experiments. 75 | * `DYNAPIPE_DEBUG_DUMP_EP_PREFIX`: The directory for dumping the above artifacts. 76 | 77 | ## Code Structure 78 | ``` 79 | ├── dynapipe 80 | │ : main source folder 81 | │ ├── data_opt 82 | │ │ : code for micro-batch splitting and cost models 83 | │ ├── memory_opt 84 | │ │ : contains the modified CUDA caching memory allocator 85 | │ │ from PyTorch 86 | │ ├── pipe 87 | │ │ : contains the implementation of pipeline instructions, 88 | │ │ the executor, and the distributed instruction store 89 | │ ├── schedule_opt 90 | │ │ : code for computing pipeline schedules 91 | │ └── utils 92 | │ : other utility code such as the logger 93 | ├── scripts 94 | │ : utility scripts for various purposes 95 | ├── tests 96 | │ : unit tests of different modules 97 | ``` 98 | 99 | 100 | ## Security 101 | 102 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 103 | 104 | ## License 105 | 106 | This project is licensed under the Apache-2.0 License. -------------------------------------------------------------------------------- /THIRD-PARTY-LICENSES: -------------------------------------------------------------------------------- 1 | ** DeepSpeed; version v0.10.0 -- https://github.com/microsoft/DeepSpeed/tree/v0.10.0 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License.
27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. 
If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | * For DeepSpeed see also this required NOTICE: 205 | # Copyright (c) Microsoft Corporation. 
206 | # SPDX-License-Identifier: Apache-2.0 207 | 208 | # DeepSpeed Team 209 | 210 | ------ 211 | 212 | ** PyTorch; version 2.1.1 -- https://github.com/pytorch/pytorch 213 | From PyTorch: 214 | 215 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 216 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 217 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 218 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 219 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 220 | Copyright (c) 2011-2013 NYU (Clement Farabet) 221 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, 222 | Iain Melvin, Jason Weston) 223 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 224 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, 225 | Johnny Mariethoz) 226 | 227 | From Caffe2: 228 | 229 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 230 | 231 | All contributions by Facebook: 232 | Copyright (c) 2016 Facebook Inc. 233 | 234 | All contributions by Google: 235 | Copyright (c) 2015 Google Inc. 236 | All rights reserved. 237 | 238 | All contributions by Yangqing Jia: 239 | Copyright (c) 2015 Yangqing Jia 240 | All rights reserved. 241 | 242 | All contributions by Kakao Brain: 243 | Copyright 2019-2020 Kakao Brain 244 | 245 | All contributions from Caffe: 246 | Copyright(c) 2013, 2014, 2015, the respective contributors 247 | All rights reserved. 248 | 249 | All other contributions: 250 | Copyright(c) 2015, 2016 the respective contributors 251 | All rights reserved. 252 | 253 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 254 | copyright over their contributions to Caffe2. The project versioning records 255 | all such contribution and copyright details. If a contributor wants to further 256 | mark their specific copyright on a particular contribution, they should 257 | indicate their copyright solely in the commit message of the change when it is 258 | committed. 259 | 260 | All rights reserved. 261 | 262 | Redistribution and use in source and binary forms, with or without 263 | modification, are permitted provided that the following conditions are met: 264 | 265 | 1. Redistributions of source code must retain the above copyright 266 | notice, this list of conditions and the following disclaimer. 267 | 268 | 2. Redistributions in binary form must reproduce the above copyright 269 | notice, this list of conditions and the following disclaimer in the 270 | documentation and/or other materials provided with the distribution. 271 | 272 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories 273 | America 274 | and IDIAP Research Institute nor the names of its contributors may be 275 | used to endorse or promote products derived from this software without 276 | specific prior written permission. 277 | 278 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 279 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 280 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 281 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 282 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 283 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 284 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 285 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 286 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 287 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 288 | POSSIBILITY OF SUCH DAMAGE. 289 | -------------------------------------------------------------------------------- /docs/images/overview_horizontal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/optimizing-multitask-training-through-dynamic-pipelines/d8666ad4e112d0241cb7fc025ce9bea72c32b924/docs/images/overview_horizontal.png -------------------------------------------------------------------------------- /dynapipe/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from .model import DynaPipeCluster, DynaPipeMicrobatch 5 | from .utils.memory_utils import TransformerMemoryModel 6 | 7 | __all__ = [ 8 | "TransformerMemoryModel", 9 | "DynaPipeMicrobatch", 10 | "DynaPipeCluster", 11 | ] 12 | -------------------------------------------------------------------------------- /dynapipe/data_opt/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color 2 | OS := $(shell uname) 3 | ifeq ($(OS),Darwin) 4 | CXXFLAGS += -undefined dynamic_lookup 5 | endif 6 | 7 | CPPFLAGS += $(shell python3 -m pybind11 --includes) 8 | LIBNAME = dp_helper 9 | LIBEXT = $(shell python3-config --extension-suffix) 10 | 11 | 12 | default: $(LIBNAME)$(LIBEXT) 13 | 14 | %$(LIBEXT): %.cpp 15 | $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ 16 | -------------------------------------------------------------------------------- /dynapipe/data_opt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | -------------------------------------------------------------------------------- /dynapipe/memory_opt/allocation_simulator.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from dataclasses import dataclass, field 5 | 6 | from sortedcontainers import SortedList 7 | 8 | 9 | class AllocatorSimulator: 10 | def __init__(self, max_memory_mbytes) -> None: 11 | self.max_memory = max_memory_mbytes * 1e6 12 | 13 | def malloc(self, size): 14 | pass 15 | 16 | def free(self, ptr): 17 | pass 18 | 19 | 20 | @dataclass 21 | class TorchCachingAllocatorConfig: 22 | max_split_size: float = float("inf") 23 | garbage_collection_threshold: float = 0.0 24 | kMinBlockSize: int = 512 25 | # kSmallSize: int = 1048576 26 | kSmallSize: int = 0 27 | kSmallBuffer: int = 2097152 28 | kLargeBuffer: int = 20971520 29 | kMinLargeAlloc: int = 10485760 30 | kRoundLarge: int = 2097152 31 | kRoundUpPowerOfTwoIntervals: int = 16 32 | 33 | 34 | @dataclass 35 | class TorchBlock: 36 | stream: int 37 | size: int 38 | ptr: int = -1 39 | pool: "TorchBlockPool" = field(default=None, repr=False) 40 | allocated: bool = False 41 | requested_size: int = -1 42 | prev: "TorchBlock" = None 43 | next: "TorchBlock" = None 44 | 45 | @classmethod 46 | def compare_key(cls, x: "TorchBlock"): 47 | return x.stream, x.size, x.ptr 48 | 49 | def is_split(self): 50 | return self.prev is not None or self.next is not None 51 | 52 | 53 | @dataclass 54 | class TorchBlockPool: 55 | blocks: SortedList = field( 56 | default_factory=lambda: SortedList(key=TorchBlock.compare_key) 57 | ) 58 | is_small: bool = False 59 | 60 | 61 | @dataclass 62 | class TorchAllocParams: 63 | size: int 64 | stream: int 65 | pool: TorchBlockPool 66 | alloc_size: int 67 | 68 | 69 | class TorchCachingAllocatorSimulator(AllocatorSimulator): 70 | def __init__( 71 | self, 72 | max_memory_mbytes, 73 | allocator_config: TorchCachingAllocatorConfig = None, 74 | ) -> None: 75 | super().__init__(max_memory_mbytes) 76 | if allocator_config is None: 77 | allocator_config = TorchCachingAllocatorConfig() 78 | self.config = allocator_config 79 | self.allocated_bytes = 0 80 | self.peak_allocated_bytes = 0 81 | self.allocated_segments = 0 82 | self.peak_allocated_segments = 0 83 | self.backend_allocated_bytes = 0 84 | self.peak_backend_allocated_bytes = 0 85 | self.n_backend_mallocs = 0 86 | self.n_backend_frees = 0 87 | self.n_alloc_large_pool = 0 88 | self.n_alloc_small_pool = 0 89 | # we only increase cuda_ptr 90 | # this means we assume cuda malloc has zero fragmentation 91 | self._backend_ptr = 0 92 | self._backend_ptr_to_size = {} 93 | self._large_pool = TorchBlockPool(is_small=False) 94 | self._small_pool = TorchBlockPool(is_small=True) 95 | self._timestep = 0 96 | 97 | def _round_size(self, size: int): 98 | mb = self.config.kMinBlockSize 99 | if size < mb: 100 | return mb 101 | else: 102 | return mb * ((size + mb - 1) // mb) 103 | 104 | def _get_pool(self, size: int): 105 | if size <= self.config.kSmallSize: 106 | return self._small_pool 107 | else: 108 | return self._large_pool 109 | 110 | def _get_allocation_size(self, size: int): 111 | if size <= self.config.kSmallSize: 112 | return self.config.kSmallBuffer 113 | elif size < self.config.kMinLargeAlloc: 114 | return self.config.kLargeBuffer 115 | else: 116 | rl = self.config.kRoundLarge 117 | return rl * ((size + rl - 1) // rl) 118 | 119 | def _get_free_block(self, params: TorchAllocParams): 120 | pool = params.pool 121 | stream = params.stream 122 | size = params.size 123 | block_index = pool.blocks.bisect_left( 124 | TorchBlock(stream=stream, size=size) 125 | ) 126 | if ( 127 | block_index == len(pool.blocks) 128 | or pool.blocks[block_index].stream != 
stream 129 | ): 130 | return None 131 | block: TorchBlock = pool.blocks[block_index] 132 | # Do not return an oversized block for a large request 133 | if (size < self.config.max_split_size) and ( 134 | block.size >= self.config.max_split_size 135 | ): 136 | return None 137 | # Allow oversized block size to be rounded up but within a limit 138 | if (size >= self.config.max_split_size) and ( 139 | block.size >= size + self.config.kLargeBuffer 140 | ): 141 | return None 142 | pool.blocks.remove(block) 143 | return block 144 | 145 | def _alloc_block(self, params: TorchAllocParams): 146 | size = params.alloc_size 147 | backend_ptr = self.backend_malloc(size) 148 | if backend_ptr == -1: 149 | return None 150 | return TorchBlock( 151 | stream=params.stream, size=size, ptr=backend_ptr, pool=params.pool 152 | ) 153 | 154 | def _release_block(self, block: TorchBlock): 155 | if block.ptr != -1: 156 | self.backend_free(block.ptr) 157 | pool = block.pool 158 | pool.blocks.remove(block) 159 | 160 | def _release_available_cached_blocks(self, params: TorchAllocParams): 161 | if self.config.max_split_size == float("inf"): 162 | return False 163 | pool = params.pool 164 | key_block = TorchBlock(stream=params.stream, size=params.size) 165 | if key_block.size < self.config.max_split_size: 166 | key_block.size = self.config.max_split_size 167 | key_index = pool.blocks.bisect_left(key_block) 168 | if ( 169 | key_index == len(pool.blocks) 170 | or pool.blocks[key_index].stream != params.stream 171 | ): 172 | # No single block is large enough; free multiple oversize blocks, 173 | # starting with the largest 174 | if key_index == 0: 175 | return False 176 | total_released = 0 177 | key_index -= 1 178 | while ( 179 | total_released < key_block.size 180 | and pool.blocks[key_index].size >= self.config.max_split_size 181 | and pool.blocks[key_index].stream == params.stream 182 | ): 183 | cur_block = pool.blocks[key_index] 184 | total_released += cur_block.size 185 | if key_index != 0: 186 | key_index -= 1 187 | self._release_block(cur_block) 188 | else: 189 | self._release_block(cur_block) 190 | break 191 | if total_released < key_block.size: 192 | return False 193 | else: 194 | self._release_block(pool.blocks[key_index]) 195 | return True 196 | 197 | def _release_blocks(self, pool: TorchBlockPool): 198 | for block in pool.blocks.copy(): 199 | if block.prev is None and block.next is None: 200 | self._release_block(block) 201 | 202 | def _release_cached_blocks(self): 203 | self._release_blocks(self._large_pool) 204 | self._release_blocks(self._small_pool) 205 | 206 | def _should_split(self, block: TorchBlock, size: int): 207 | remaining = block.size - size 208 | if block.pool.is_small: 209 | return remaining >= self.config.kMinBlockSize 210 | else: 211 | return ( 212 | size < self.config.max_split_size 213 | and remaining > self.config.kSmallSize 214 | ) 215 | 216 | def _alloc_found_block( 217 | self, 218 | block: TorchBlock, 219 | params: TorchAllocParams, 220 | orig_size: int, 221 | split_remainder: bool, 222 | ): 223 | size = params.size 224 | pool = params.pool 225 | stream = params.stream 226 | assert block is not None and block.ptr != -1 227 | if split_remainder: 228 | remaining = block 229 | block = TorchBlock( 230 | stream=stream, size=size, ptr=block.ptr, pool=pool 231 | ) 232 | block.prev = remaining.prev 233 | if block.prev: 234 | block.prev.next = block 235 | block.next = remaining 236 | remaining.prev = block 237 | remaining.ptr += size 238 | remaining.size -= size 239 | pool.blocks.add(remaining) 240 | 
block.allocated = True 241 | block.requested_size = orig_size 242 | assert block.size > 0 243 | self.allocated_bytes += block.size 244 | self.peak_allocated_bytes = max( 245 | self.peak_allocated_bytes, self.allocated_bytes 246 | ) 247 | return block 248 | 249 | # combine previously split blocks. returns the size of the subsumed block, 250 | # or 0 on failure. 251 | def _try_merge_blocks( 252 | self, dst: TorchBlock, src: TorchBlock, pool: TorchBlockPool 253 | ): 254 | if not src or src.allocated: 255 | return 0 256 | assert src.is_split() and dst.is_split() 257 | if dst.prev == src: # [src, dst] 258 | dst.ptr = src.ptr 259 | dst.prev = src.prev 260 | if dst.prev: 261 | dst.prev.next = dst 262 | else: # [dst, src] 263 | dst.next = src.next 264 | if dst.next: 265 | dst.next.prev = dst 266 | subsumed_size = src.size 267 | dst.size += subsumed_size 268 | pool.blocks.remove(src) 269 | return subsumed_size 270 | 271 | def _free_block(self, block: TorchBlock): 272 | assert not block.allocated 273 | pool = block.pool 274 | 275 | merge_candidates = (block.prev, block.next) 276 | for candidate in merge_candidates: 277 | self._try_merge_blocks(block, candidate, pool) 278 | pool.blocks.add(block) 279 | 280 | def backend_malloc(self, size): 281 | if self.backend_allocated_bytes + size > self.max_memory: 282 | return -1 283 | self._backend_ptr += size 284 | self.backend_allocated_bytes += size 285 | self._backend_ptr_to_size[self._backend_ptr] = size 286 | self.allocated_segments += 1 287 | self.peak_allocated_segments = max( 288 | self.peak_allocated_segments, self.allocated_segments 289 | ) 290 | self.peak_backend_allocated_bytes = max( 291 | self.peak_backend_allocated_bytes, self.backend_allocated_bytes 292 | ) 293 | self.n_backend_mallocs += 1 294 | return self._backend_ptr 295 | 296 | def backend_free(self, ptr): 297 | assert ptr in self._backend_ptr_to_size 298 | size = self._backend_ptr_to_size[ptr] 299 | self.backend_allocated_bytes -= size 300 | self.allocated_segments -= 1 301 | self.n_backend_frees += 1 302 | 303 | def malloc(self, size, stream=0): 304 | size = self._round_size(size) 305 | pool = self._get_pool(size) 306 | if pool.is_small: 307 | self.n_alloc_small_pool += 1 308 | else: 309 | self.n_alloc_large_pool += 1 310 | alloc_size = self._get_allocation_size(size) 311 | param = TorchAllocParams( 312 | size=size, stream=stream, pool=pool, alloc_size=alloc_size 313 | ) 314 | block = self._get_free_block(param) 315 | # we don't simulate free block callbacks for now 316 | if block is None: 317 | # attempt allocation 318 | block = self._alloc_block(param) 319 | if block is None: 320 | self._release_available_cached_blocks(param) 321 | block = self._alloc_block(param) 322 | if block is None: 323 | self._release_cached_blocks() 324 | block = self._alloc_block(param) 325 | if block is None: 326 | # we are out of memory 327 | raise RuntimeError("Out of Memory") 328 | assert block is not None 329 | should_split_remainder = self._should_split(block, param.size) 330 | self._timestep += 1 331 | return self._alloc_found_block( 332 | block, param, size, should_split_remainder 333 | ) 334 | 335 | def free(self, block: TorchBlock): 336 | block.allocated = False 337 | orig_size = block.size 338 | self._free_block(block) 339 | self.allocated_bytes -= orig_size 340 | self._timestep += 1 341 | 342 | def clear_cache(self): 343 | self._release_cached_blocks() 344 | 345 | def reset_peak_stats(self): 346 | self.peak_allocated_bytes = 0 347 | self.peak_allocated_segments = 0 348 | 
self.peak_backend_allocated_bytes = 0 349 | -------------------------------------------------------------------------------- /dynapipe/memory_opt/cuda_caching_allocator.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import ctypes 5 | import glob 6 | import os 7 | 8 | import torch 9 | 10 | _allocator = None 11 | 12 | 13 | class DynaPipeCachingAllocator: 14 | # wrapper for the C++ allocator 15 | def __init__(self, dll): 16 | self.dll = dll 17 | self._c_peak_reserved_cuda_memory = self.get_func( 18 | "dynapipe_get_peak_reserved_cuda_memory" 19 | ) 20 | self._c_peak_reserved_cuda_memory.restype = ctypes.c_int64 21 | self._c_peak_allocated_cuda_memory = self.get_func( 22 | "dynapipe_get_peak_allocated_cuda_memory" 23 | ) 24 | self._c_peak_allocated_cuda_memory.restype = ctypes.c_int64 25 | self._c_peak_requested_cuda_memory = self.get_func( 26 | "dynapipe_get_peak_requested_cuda_memory" 27 | ) 28 | self._c_peak_requested_cuda_memory.restype = ctypes.c_int64 29 | self._c_current_reserved_cuda_memory = self.get_func( 30 | "dynapipe_get_current_reserved_cuda_memory" 31 | ) 32 | self._c_current_reserved_cuda_memory.restype = ctypes.c_int64 33 | self._c_current_allocated_cuda_memory = self.get_func( 34 | "dynapipe_get_current_allocated_cuda_memory" 35 | ) 36 | self._c_current_allocated_cuda_memory.restype = ctypes.c_int64 37 | self._c_current_requested_cuda_memory = self.get_func( 38 | "dynapipe_get_current_requested_cuda_memory" 39 | ) 40 | self._c_current_requested_cuda_memory.restype = ctypes.c_int64 41 | self._c_get_memory_snapshot = self.get_func( 42 | "dynapipe_get_memory_snapshot" 43 | ) 44 | self._c_get_memory_snapshot.argtypes = [ 45 | ctypes.POINTER(ctypes.c_size_t) 46 | ] 47 | self._c_get_memory_snapshot.restype = ctypes.POINTER(ctypes.c_char) 48 | 49 | def get_func(self, func_name): 50 | return getattr(self.dll, func_name) 51 | 52 | def get_func_ptr(self, func_name): 53 | return ctypes.cast(getattr(self.dll, func_name), ctypes.c_void_p).value 54 | 55 | def num_cuda_mallocs(self): 56 | return self.get_func("dynapipe_get_num_cuda_mallocs")() 57 | 58 | def num_cuda_frees(self): 59 | return self.get_func("dynapipe_get_num_cuda_frees")() 60 | 61 | def reset_peak_stats(self): 62 | return self.get_func("dynapipe_reset_peak_stats")() 63 | 64 | def reset_accumulated_stats(self): 65 | return self.get_func("dynapipe_reset_accumulated_stats")() 66 | 67 | def peak_reserved_cuda_memory(self): 68 | return self._c_peak_reserved_cuda_memory() 69 | 70 | def peak_allocated_cuda_memory(self): 71 | return self._c_peak_allocated_cuda_memory() 72 | 73 | def peak_requested_cuda_memory(self): 74 | return self._c_peak_requested_cuda_memory() 75 | 76 | def current_reserved_cuda_memory(self): 77 | return self._c_current_reserved_cuda_memory() 78 | 79 | def current_allocated_cuda_memory(self): 80 | return self._c_current_allocated_cuda_memory() 81 | 82 | def current_requested_cuda_memory(self): 83 | return self._c_current_requested_cuda_memory() 84 | 85 | def get_memory_snapshot(self): 86 | buf_size = ctypes.c_size_t() 87 | c_ptr = self._c_get_memory_snapshot(ctypes.byref(buf_size)) 88 | return ctypes.string_at(c_ptr, buf_size.value) 89 | 90 | 91 | def find_library(): 92 | """Find the compiled library.""" 93 | # Find the library path. 
94 | library_path = os.path.join( 95 | os.path.dirname(__file__), "build", "lib*", "dynapipe_cuda_allocator.*" 96 | ) 97 | library_path = glob.glob(library_path)[0] 98 | return library_path 99 | 100 | 101 | def get_allocator(): 102 | """Get the custom allocator wrapper.""" 103 | global _allocator 104 | assert _allocator is not None, "Allocator not overriden" 105 | return _allocator 106 | 107 | 108 | def _remove_args(func): 109 | """Remove the arguments from the function.""" 110 | 111 | def wrapper(*args, **kwargs): 112 | return func() 113 | 114 | return wrapper 115 | 116 | 117 | def override_allocator(): 118 | """Override the default PyTorch allocator with the custom allocator.""" 119 | global _allocator 120 | if _allocator is not None: 121 | return 122 | # Load the library. 123 | library_path = find_library() 124 | dll = ctypes.CDLL(library_path) 125 | allocator_wrapper = DynaPipeCachingAllocator(dll) 126 | new_alloc = torch.cuda.memory.CUDAPluggableAllocator( 127 | library_path, "dynapipe_malloc", "dynapipe_free" 128 | ) 129 | new_alloc._allocator.set_memory_fraction_fn( 130 | allocator_wrapper.get_func_ptr("dynapipe_set_memory_fraction") 131 | ) 132 | new_alloc._allocator.set_release_pool( 133 | allocator_wrapper.get_func_ptr("dynapipe_release_pool") 134 | ) 135 | new_alloc._allocator.set_reset_fn( 136 | allocator_wrapper.get_func_ptr("dynapipe_reset") 137 | ) 138 | torch.cuda.memory.change_current_allocator(new_alloc) 139 | _allocator = allocator_wrapper 140 | # override torch's get memory stats function to avoid errors 141 | # Note: all args (e.g. device_index) are removed, all stats are 142 | # only for current device. Use with caution. 143 | torch.cuda.memory_allocated = _remove_args( 144 | _allocator.current_allocated_cuda_memory 145 | ) 146 | torch.cuda.max_memory_allocated = _remove_args( 147 | _allocator.peak_allocated_cuda_memory 148 | ) 149 | torch.cuda.memory_reserved = _remove_args( 150 | _allocator.current_reserved_cuda_memory 151 | ) 152 | torch.cuda.max_memory_reserved = _remove_args( 153 | _allocator.peak_reserved_cuda_memory 154 | ) 155 | torch.cuda.reset_accumulated_memory_stats = _remove_args( 156 | _allocator.reset_accumulated_stats 157 | ) 158 | torch.cuda.reset_peak_memory_stats = _remove_args( 159 | _allocator.reset_peak_stats 160 | ) 161 | -------------------------------------------------------------------------------- /dynapipe/memory_opt/host_caching_allocator.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import collections 5 | import queue 6 | from typing import Tuple, Union 7 | 8 | import torch 9 | from torch._utils import ExceptionWrapper 10 | from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL 11 | 12 | from dynapipe.memory_opt.allocation_simulator import ( 13 | TorchCachingAllocatorSimulator, 14 | ) 15 | 16 | 17 | class HostCachingAllocatorPtr: 18 | # This is used as a index to find and access the allocated torch 19 | # tensors. This acts as a normal pointer which supports addition. 
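    # For example, HostCachingAllocatorPtr(hash(t), 0) + 16 yields a pointer to
    # offset 16 within the same backing tensor t.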
20 |     def __init__(self, tensor_hash: int, offset):
21 |         if not isinstance(tensor_hash, int):
22 |             raise RuntimeError("tensor_hash must be an integer")
23 |         self.tensor_hash = tensor_hash
24 |         self.offset = offset
25 | 
26 |     def __hash__(self) -> int:
27 |         return hash((self.tensor_hash, self.offset))
28 | 
29 |     def __add__(self, other: Union["HostCachingAllocatorPtr", int]):
30 |         if not isinstance(other, (HostCachingAllocatorPtr, int)):
31 |             raise RuntimeError(
32 |                 "Cannot add a HostCachingAllocatorPtr and a non integer"
33 |             )
34 |         if isinstance(other, int):
35 |             return HostCachingAllocatorPtr(
36 |                 self.tensor_hash, self.offset + other
37 |             )
38 |         if self.tensor_hash != other.tensor_hash:
39 |             raise RuntimeError("Cannot add two different tensors")
40 |         return HostCachingAllocatorPtr(self.tensor_hash, self.offset + other.offset)
41 | 
42 |     def __radd__(self, other: Union["HostCachingAllocatorPtr", int]):
43 |         return self.__add__(other)
44 | 
45 |     def __eq__(self, other: "HostCachingAllocatorPtr"):
46 |         if not isinstance(other, HostCachingAllocatorPtr):
47 |             return False
48 |         return (
49 |             self.tensor_hash == other.tensor_hash
50 |             and self.offset == other.offset
51 |         )
52 | 
53 |     def __lt__(self, other: "HostCachingAllocatorPtr"):
54 |         if not isinstance(other, HostCachingAllocatorPtr):
55 |             # it may be used to compare with nullptr (-1)
56 |             # we always want to return false in this case
57 |             return False
58 |         if self.tensor_hash == other.tensor_hash:
59 |             return self.offset < other.offset
60 |         return self.tensor_hash < other.tensor_hash
61 | 
62 |     def __gt__(self, other: "HostCachingAllocatorPtr"):
63 |         if self.tensor_hash == other.tensor_hash:
64 |             return self.offset > other.offset
65 |         return self.tensor_hash > other.tensor_hash
66 | 
67 |     def __le__(self, other: "HostCachingAllocatorPtr"):
68 |         if self.tensor_hash == other.tensor_hash:
69 |             return self.offset <= other.offset
70 |         return self.tensor_hash <= other.tensor_hash
71 | 
72 |     def __ge__(self, other: "HostCachingAllocatorPtr"):
73 |         if self.tensor_hash == other.tensor_hash:
74 |             return self.offset >= other.offset
75 |         return self.tensor_hash >= other.tensor_hash
76 | 
77 |     def __ne__(self, other: "HostCachingAllocatorPtr"):
78 |         if not isinstance(other, HostCachingAllocatorPtr):
79 |             return True
80 |         return (
81 |             self.tensor_hash != other.tensor_hash
82 |             or self.offset != other.offset
83 |         )
84 | 
85 | 
86 | class DestructionCallback:
87 |     # This is attached to the tensor as an attribute.
88 |     # The destruction callback is called when the tensor is deleted.
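    # Illustrative sketch of the lifecycle (hca is a HostCachingAllocator
    # instance, as used further below; not a complete example):
    #
    #     view = hca.malloc((1024,), torch.uint8)  # pinned view, callback set
    #     del view  # last reference dropped -> __del__ -> hca.free(ptr)
    #
    # so cached host memory is returned without an explicit free() call.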
89 | def __init__(self, callback): 90 | self.callback = callback 91 | 92 | def __del__(self): 93 | self.callback() 94 | 95 | 96 | class HostCachingAllocator(TorchCachingAllocatorSimulator): 97 | def __init__(self, allocator_config=None) -> None: 98 | super().__init__(float("inf"), allocator_config) 99 | self._segment_map = {} 100 | self._block_map = {} 101 | 102 | # override 103 | def backend_malloc(self, size): 104 | # we don't need to manually pose a limit on the memory 105 | # _backend_ptr is not used 106 | tensor = torch.empty( 107 | size, dtype=torch.uint8, device="cpu", pin_memory=True 108 | ) 109 | tensor_hash = hash(tensor) 110 | self._segment_map[tensor_hash] = tensor 111 | tensor_ptr = HostCachingAllocatorPtr(tensor_hash, 0) 112 | self.backend_allocated_bytes += size 113 | self._backend_ptr_to_size[tensor_ptr] = size 114 | self.allocated_segments += 1 115 | self.peak_allocated_segments = max( 116 | self.peak_allocated_segments, self.allocated_segments 117 | ) 118 | self.peak_backend_allocated_bytes = max( 119 | self.peak_backend_allocated_bytes, self.backend_allocated_bytes 120 | ) 121 | self.n_backend_mallocs += 1 122 | return tensor_ptr 123 | 124 | def backend_free(self, ptr): 125 | if not isinstance(ptr, HostCachingAllocatorPtr): 126 | raise RuntimeError("ptr must be a HostCachingAllocatorPtr") 127 | if ptr not in self._backend_ptr_to_size: 128 | raise RuntimeError("ptr is not a valid pointer") 129 | tensor_hash = ptr.tensor_hash 130 | if tensor_hash not in self._segment_map: 131 | raise RuntimeError("tensor_hash do not map to a allocated tensor") 132 | del self._segment_map[tensor_hash] 133 | size = self._backend_ptr_to_size[ptr] 134 | del self._backend_ptr_to_size[ptr] 135 | self.backend_allocated_bytes -= size 136 | self.allocated_segments -= 1 137 | self.n_backend_frees += 1 138 | 139 | # wrapper around the original malloc, returning a pinned tensor 140 | def malloc( 141 | self, shape: Tuple[int, ...], dtype: torch.dtype 142 | ) -> torch.Tensor: 143 | size = 1 144 | for dim in shape: 145 | size *= dim 146 | size *= torch._utils._element_size(dtype) 147 | block = super().malloc(size) 148 | ptr: HostCachingAllocatorPtr = block.ptr 149 | tensor_hash = ptr.tensor_hash 150 | if tensor_hash not in self._segment_map: 151 | raise RuntimeError("tensor_hash do not map to a allocated tensor") 152 | tensor = self._segment_map[tensor_hash] 153 | self._block_map[ptr] = block 154 | tensor_view = ( 155 | tensor[ptr.offset : ptr.offset + size].view(dtype).view(shape) 156 | ) 157 | 158 | def destructor(): 159 | self.free(ptr) 160 | 161 | tensor_view._destructor = DestructionCallback(destructor) 162 | return tensor_view 163 | 164 | def free(self, ptr: HostCachingAllocatorPtr): 165 | if ptr not in self._block_map: 166 | raise RuntimeError("ptr is not a valid pointer") 167 | block = self._block_map[ptr] 168 | del self._block_map[ptr] 169 | super().free(block) 170 | 171 | 172 | # copied from torch.utils.data._utils.pin_memory 173 | def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device): 174 | # This setting is thread local, and prevents the copy in pin_memory from 175 | # consuming all CPU cores. 
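    # Context, as an illustrative note rather than a guarantee: this function
    # replaces torch.utils.data._utils.pin_memory._pin_memory_loop once
    # apply_monkey_patch() (defined at the end of this file) has been called,
    # so a DataLoader created with pin_memory=True stages its batches through
    # a HostCachingAllocator instead of allocating fresh pinned buffers, e.g.
    #
    #     apply_monkey_patch()
    #     loader = torch.utils.data.DataLoader(
    #         dataset, batch_size=8, num_workers=2, pin_memory=True
    #     )
    #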
176 | torch.set_num_threads(1) 177 | 178 | if device == "cuda": 179 | torch.cuda.set_device(device_id) 180 | elif device == "xpu": 181 | torch.xpu.set_device(device_id) # type: ignore[attr-defined] 182 | 183 | hca = HostCachingAllocator() 184 | # pre-allocate a large chunk of memory 185 | x = hca.malloc((1024 * 1024 * 1024,), torch.uint8) 186 | del x 187 | 188 | def do_one_step(): 189 | try: 190 | r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) 191 | except queue.Empty: 192 | return 193 | idx, data = r 194 | if not done_event.is_set() and not isinstance(data, ExceptionWrapper): 195 | try: 196 | data = _pin_memory(hca, data, device) 197 | except Exception: 198 | data = ExceptionWrapper( 199 | where="in pin memory thread for device {}".format( 200 | device_id 201 | ) 202 | ) 203 | r = (idx, data) 204 | while not done_event.is_set(): 205 | try: 206 | out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL) 207 | break 208 | except queue.Full: 209 | continue 210 | 211 | # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details 212 | # on the logic of this function. 213 | while not done_event.is_set(): 214 | # Make sure that we don't preserve any object from one iteration 215 | # to the next 216 | do_one_step() 217 | 218 | 219 | # monkey patches 220 | # copied from torch.utils.data._utils.pin_memory 221 | def _pin_memory(hca: HostCachingAllocator, data, device=None): 222 | if isinstance(data, torch.Tensor): 223 | pinned_tensor = hca.malloc(data.shape, data.dtype) 224 | pinned_tensor.copy_(data) 225 | return pinned_tensor 226 | elif isinstance(data, str): 227 | return data 228 | elif isinstance(data, collections.abc.Mapping): 229 | try: 230 | return type(data)( 231 | { 232 | k: _pin_memory(hca, sample, device) 233 | for k, sample in data.items() 234 | } 235 | ) 236 | except TypeError: 237 | # The mapping type may not support `__init__(iterable)`. 238 | return { 239 | k: _pin_memory(hca, sample, device) 240 | for k, sample in data.items() 241 | } 242 | elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple 243 | return type(data)( 244 | *(_pin_memory(hca, sample, device) for sample in data) 245 | ) 246 | elif isinstance(data, tuple): 247 | return [ 248 | _pin_memory(hca, sample, device) for sample in data 249 | ] # Backwards compatibility. 250 | elif isinstance(data, collections.abc.Sequence): 251 | try: 252 | return type(data)( 253 | [_pin_memory(hca, sample, device) for sample in data] 254 | ) 255 | except TypeError: 256 | # The sequence type may not support `__init__(iterable)` 257 | # (e.g., `range`). 258 | return [_pin_memory(hca, sample, device) for sample in data] 259 | elif hasattr(data, "pin_memory"): 260 | raise RuntimeError("Custom pin_memory function not supported.") 261 | else: 262 | return data 263 | 264 | 265 | def apply_monkey_patch(): 266 | # monkey patch the pin_memory function to use our caching HCA 267 | torch.utils.data._utils.pin_memory._pin_memory_loop = _pin_memory_loop 268 | -------------------------------------------------------------------------------- /dynapipe/memory_opt/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from setuptools import setup 5 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 6 | 7 | setup( 8 | name="dynapipe_cuda_allocator", 9 | ext_modules=[ 10 | CUDAExtension( 11 | name="dynapipe_cuda_allocator", 12 | sources=["cuda_caching_allocator.cpp"], 13 | extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"]}, 14 | ) 15 | ], 16 | cmdclass={"build_ext": BuildExtension}, 17 | ) 18 | -------------------------------------------------------------------------------- /dynapipe/memory_opt/traceback.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * Modifications Copyright (c) Facebook, Inc. 5 | * See: https://github.com/pytorch/pytorch/blob/main/torch/csrc/profiler/combined_traceback.h 6 | * https://github.com/pytorch/pytorch/blob/main/torch/csrc/profiler/unwind/unwind.h 7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include 12 | #include 13 | 14 | // copied from torch/csrc/profiler/combined_traceback.h 15 | // and torch/csrc/profiler/unwind/unwind.h 16 | 17 | namespace torch { 18 | namespace unwind { 19 | // gather current stack, relatively fast. 20 | // gets faster once the cache of program counter locations is warm. 21 | TORCH_API std::vector unwind(); 22 | 23 | struct Frame { 24 | std::string filename; 25 | std::string funcname; 26 | uint64_t lineno; 27 | }; 28 | 29 | // note: symbolize is really slow 30 | // it will launch an addr2line process that has to parse dwarf 31 | // information from the libraries that frames point into. 32 | // Callers should first batch up all the unique void* pointers 33 | // across a number of unwind states and make a single call to 34 | // symbolize. 35 | TORCH_API std::vector symbolize(const std::vector& frames); 36 | 37 | struct Stats { 38 | size_t hits = 0; 39 | size_t misses = 0; 40 | size_t unsupported = 0; 41 | size_t resets = 0; 42 | }; 43 | Stats stats(); 44 | 45 | } // namespace unwind 46 | } // namespace torch 47 | 48 | #include 49 | namespace torch { 50 | 51 | // struct that holds the result of symbolizing multiple tracebacks 52 | // each traceback is a list of indices into all_frames 53 | // (lots of Frames get duplicated across traces) 54 | struct TORCH_API SymbolizedTracebacks { 55 | std::vector all_frames; 56 | // index into all_frames, so that 57 | // it is possible to dedupe frame objects in 58 | // construction of python objects 59 | std::vector> tracebacks; 60 | }; 61 | 62 | struct TORCH_API CapturedTraceback : public c10::GatheredContext { 63 | struct PyFrame { 64 | void* code; // PyCodeObject*, but python headers not present 65 | int lasti; 66 | }; 67 | 68 | static std::shared_ptr gather( 69 | bool python, 70 | bool script, 71 | bool cpp); 72 | CapturedTraceback() = default; 73 | CapturedTraceback(const CapturedTraceback&) = delete; 74 | CapturedTraceback& operator=(const CapturedTraceback&) = delete; 75 | ~CapturedTraceback(); 76 | struct Python { 77 | virtual std::vector gather() = 0; 78 | virtual void release(std::vector& frames) = 0; 79 | virtual void appendSymbolized( 80 | const std::vector& to_symbolize, 81 | SymbolizedTracebacks& st) = 0; 82 | virtual ~Python() = default; 83 | Python* next_ = nullptr; 84 | }; 85 | // called once by each python interpreter to 86 | // register python stack recording functionality 87 | // p cannot be deleted once added. 
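  //
  // Illustrative usage note for this struct (not part of the upstream
  // header): callers typically capture contexts with
  //   auto tb = torch::CapturedTraceback::gather(/*python=*/true,
  //                                              /*script=*/true,
  //                                              /*cpp=*/true);
  // collect many of them, and resolve them in a single batched
  // torch::symbolize(...) call, since symbolization is the slow step.
  //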
88 | static void addPythonUnwinder(Python* p); 89 | 90 | private: 91 | std::vector frames_; 92 | std::vector cpp_frames_; 93 | std::vector script_frames_; 94 | friend TORCH_API SymbolizedTracebacks 95 | symbolize(const std::vector& to_symbolize); 96 | 97 | // non-owning reference to one of the immortal Python* objects 98 | // registered above. 99 | Python* python_ = nullptr; 100 | }; 101 | 102 | TORCH_API SymbolizedTracebacks 103 | symbolize(const std::vector& to_symbolize); 104 | 105 | } // namespace torch 106 | -------------------------------------------------------------------------------- /dynapipe/memory_opt/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | 6 | from dynapipe.memory_opt.cuda_caching_allocator import get_allocator 7 | 8 | 9 | def reserve_full_memory(custom_allocator=True): 10 | torch.cuda.synchronize() 11 | torch.cuda.empty_cache() 12 | total_memory = torch.cuda.get_device_properties( 13 | torch.cuda.current_device() 14 | ).total_memory 15 | # try to allocate all memory 16 | while True: 17 | try: 18 | t = torch.empty( 19 | total_memory, 20 | dtype=torch.uint8, 21 | device=torch.cuda.current_device(), 22 | ) 23 | except RuntimeError as e: 24 | if "CUDA out of memory" in str(e): 25 | total_memory -= 128 * 1024 * 1024 # 128MB 26 | continue 27 | else: 28 | # not an OOM error 29 | raise e 30 | break 31 | if custom_allocator: 32 | allocator = get_allocator() 33 | current_memory = allocator.current_allocated_cuda_memory() 34 | else: 35 | current_memory = torch.cuda.memory_allocated() 36 | del t 37 | # # free the memory 38 | # torch.cuda.empty_cache() 39 | # # allocate again, but reduce 1GB for other stuff 40 | # total_memory = int(total_memory - 1e9) 41 | # t = torch.empty( 42 | # total_memory, dtype=torch.uint8, device=torch.cuda.current_device() 43 | # ) 44 | # current memory should be higher than the maximum 45 | # memory limit of the device. currently adjusted by hand 46 | # but may be automated. 47 | return total_memory, current_memory 48 | -------------------------------------------------------------------------------- /dynapipe/pipe/executor.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from typing import Type 5 | 6 | from dynapipe.utils.logger import create_logger 7 | 8 | from .instructions import * # noqa: F403 9 | 10 | 11 | def _handle_free_buffer(exec: "PipelineExecutor", instr: FreeBuffer): 12 | # free buffer just removes the buffer from the buffer slots 13 | buffer_ids = instr.buffer_ids 14 | for buffer_id in buffer_ids: 15 | exec.buffer_slots[buffer_id] = None 16 | 17 | 18 | class PipelineExecutor: 19 | """ 20 | Executes the dynamic pipeline according to pipeline instructions. 
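
    Typical wiring, shown as an illustrative sketch (the handler functions
    are user-supplied placeholders, not part of this module):

        executor = PipelineExecutor(dp_rank=0, pp_rank=0)
        executor.register_handler(ForwardPass, handle_forward)
        executor.register_handler(BackwardPass, handle_backward)
        # ... one handler per leaf PipeInstruction type; see
        #     get_all_needed_handlers() for the full list
        executor.check_all_handlers_registered()
        executor.execute(execution_plan, iteration=0)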
21 | """ 22 | 23 | def __init__( 24 | self, 25 | dp_rank: int = None, 26 | pp_rank: int = None, 27 | synchronous: bool = False, 28 | ): 29 | # rank is optional, only used for debugging 30 | self.buffer_slots = [] 31 | self._instruction_handlers = {} 32 | self.dp_rank = dp_rank 33 | self.pp_rank = pp_rank 34 | # if synchronous, cuda device synchronization is performed after 35 | # each instruction, mainly used for debugging 36 | self.synchronous = synchronous 37 | # register default handlers 38 | self.register_handler(FreeBuffer, _handle_free_buffer) 39 | self.logger = create_logger( 40 | "PipelineExecutor", 41 | prefix=f"dRank {dp_rank} pRank {pp_rank}", 42 | log_file=f"executor/dr{dp_rank}_pr{pp_rank}.log", 43 | ) 44 | self.instr_index = 0 45 | self.current_iteration = None 46 | self.is_last_micro_batch = False 47 | 48 | def execute(self, execution_plan: ExecutionPlan, iteration=None): 49 | self.execution_plan = execution_plan 50 | self.buffer_slots = [None] * execution_plan.num_pipe_buffers 51 | if iteration is not None: 52 | self.current_iteration = iteration 53 | self.logger.debug("Executing iteration %d", iteration) 54 | for instr_index, instruction in enumerate(execution_plan.instructions): 55 | self.logger.debug("Executing instruction: %s", instruction) 56 | self.instr_index = instr_index 57 | if ( 58 | instruction.microbatch is not None 59 | and instruction.microbatch 60 | == execution_plan.num_micro_batches - 1 61 | ): 62 | self.is_last_micro_batch = True 63 | else: 64 | self.is_last_micro_batch = False 65 | self._execute_instruction(instruction) 66 | if iteration is not None: 67 | self.logger.debug("Finished executing iteration %d", iteration) 68 | 69 | def register_handler( 70 | self, instruction_type: Type[PipeInstruction], handler 71 | ): 72 | if not issubclass(instruction_type, PipeInstruction): 73 | raise TypeError( 74 | f"Instruction type must be a subclass of PipeInstruction, " 75 | f"got {instruction_type.__name__}" 76 | ) 77 | if instruction_type in self._instruction_handlers: 78 | raise ValueError( 79 | f"Instruction handler for {instruction_type.__name__} " 80 | "already registered." 
81 | ) 82 | self._instruction_handlers[instruction_type] = handler 83 | 84 | @classmethod 85 | def _get_leaf_subclasses(cls, instruction_type: Type[PipeInstruction]): 86 | subclasses = instruction_type.__subclasses__() 87 | if subclasses: 88 | for subclass in subclasses: 89 | yield from cls._get_leaf_subclasses(subclass) 90 | else: 91 | yield instruction_type 92 | 93 | @classmethod 94 | def get_all_needed_handlers(cls): 95 | needed_handlers = set( 96 | cls._get_leaf_subclasses(PipeInstruction) 97 | ).difference([FreeBuffer]) 98 | return [x.__name__ for x in needed_handlers] 99 | 100 | def check_all_handlers_registered(self): 101 | for instruction_type in self._get_leaf_subclasses(PipeInstruction): 102 | if instruction_type not in self._instruction_handlers: 103 | raise ValueError( 104 | "No handler registered for instruction " 105 | f"{instruction_type.__name__}" 106 | ) 107 | 108 | def register_synchronization_handler(self, handler): 109 | self.synchronize = handler 110 | 111 | def synchronize(self): 112 | raise NotImplementedError("Synchronization handler not registered") 113 | 114 | def _execute_instruction(self, instruction: PipeInstruction): 115 | handler = self._instruction_handlers.get(type(instruction)) 116 | if handler is None: 117 | raise ValueError( 118 | "No handler registered for instruction " 119 | f"{type(instruction).__name__}" 120 | ) 121 | handler(self, instruction) 122 | if self.synchronous: 123 | self.synchronize() 124 | -------------------------------------------------------------------------------- /dynapipe/pipe/instruction_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from collections import defaultdict 5 | from dataclasses import dataclass 6 | 7 | from .instructions import * # noqa: F403 8 | 9 | 10 | @dataclass 11 | class _Buffer: 12 | slot: int 13 | microbatch: int 14 | stage: int 15 | shape: Tuple[int, ...] 16 | life_start: int 17 | life_end: int 18 | 19 | 20 | def _is_forward(instr): 21 | return isinstance( 22 | instr, 23 | ( 24 | ForwardPass, 25 | SendActivationStart, 26 | SendActivationFinish, 27 | RecvActivationStart, 28 | RecvActivationFinish, 29 | LoadInput, 30 | ), 31 | ) 32 | 33 | 34 | def _is_recv_instr(instr): 35 | return isinstance(instr, (RecvActivationStart, RecvGradStart)) 36 | 37 | 38 | def _is_send_instr(instr): 39 | return isinstance(instr, (SendActivationStart, SendGradStart)) 40 | 41 | 42 | def _is_compute_instr(instr): 43 | return isinstance(instr, (ForwardPass, BackwardPass)) 44 | 45 | 46 | def _get_key(instr: PipeInstruction): 47 | return (instr.microbatch, instr.stage, _is_forward(instr)) 48 | 49 | 50 | def _fw_stage_to_bw_stage(stage: int, n_stages: int): 51 | return n_stages - 1 - stage 52 | 53 | 54 | class InstructionOptimizer: 55 | """ 56 | Inject buffer allocation/free and communication finish 57 | ops into the pipeline instructions. 58 | """ 59 | 60 | def __init__( 61 | self, 62 | per_worker_instructions: List[List[PipeInstruction]], 63 | n_stages: int, 64 | ): 65 | self.per_worker_instructions = per_worker_instructions 66 | self.n_stages = n_stages 67 | 68 | def _inject_comm_finish_instrs(self, instrs: List[PipeInstruction]): 69 | # We assume that each rank has two communication streams, 70 | # one for communication with the previous rank and one for 71 | # the next rank. This gives better communication overlap 72 | # without the possibility to deadlock. 
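        #
        # Concrete illustration of the rules described below (the microbatch
        # and stage ids are made up): a worker's sequence
        #     ..., RecvActivationStart(mb=1, stage=2), ...,
        #     ForwardPass(mb=1, stage=2), ...
        # is rewritten so that a RecvActivationFinish(mb=1, stage=2) is placed
        # immediately before the ForwardPass that consumes the received data.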
73 | # 74 | # For each RecvXXXStart, we need a RecvXXXFinish instr before 75 | # the instruction that uses the data, which is identified by 76 | # the corresponding microbatch and stage id. 77 | # 78 | # For each SendXXXStart, there is a trade-off between freeing the 79 | # memory early and unnecessary waiting if using static location for 80 | # SendXXXFinish. Therefore we dynamically query if the send is complete 81 | # during execution, and SendXXXFinish is added as late as possible, 82 | # only serving as constraints for correctness (in case dynamic query 83 | # fails). 84 | # We add SendActivationFinish only before the corresponding backward 85 | # pass, at which point the send must have completed. All SendGradFinish 86 | # are added at the end of the iteration. 87 | 88 | instr_map: Dict[ 89 | Type[CommunicationStartInstruction], 90 | Type[CommunicationFinishInsturction], 91 | ] = { 92 | SendActivationStart: SendActivationFinish, 93 | RecvActivationStart: RecvActivationFinish, 94 | SendGradStart: SendGradFinish, 95 | RecvGradStart: RecvGradFinish, 96 | } 97 | _prepend_map = {} 98 | accumulated_send_activation_finish_instrs = defaultdict(list) 99 | accumulated_send_grad_finish_instrs = [] 100 | new_instrs = [] 101 | for instr in instrs: 102 | if _is_recv_instr(instr): 103 | key = _get_key(instr) 104 | assert key not in _prepend_map 105 | _prepend_map[key] = instr 106 | elif _is_send_instr(instr): 107 | instr: CommunicationStartInstruction 108 | # get the corresponding finish instr 109 | finish_instr = instr_map[type(instr)]( 110 | instr.microbatch, instr.stage, instr.peer 111 | ) 112 | # append existing send finish instrs 113 | # new_instrs.extend(accumulated_send_finish_instrs[instr.peer].copy()) 114 | # accumulated_send_finish_instrs[instr.peer].clear() 115 | if isinstance(instr, SendActivationStart): 116 | accumulated_send_activation_finish_instrs[ 117 | ( 118 | instr.microbatch, 119 | _fw_stage_to_bw_stage(instr.stage, self.n_stages), 120 | ) 121 | ].append(finish_instr) 122 | elif isinstance(instr, SendGradStart): 123 | accumulated_send_grad_finish_instrs.append(finish_instr) 124 | else: 125 | raise RuntimeError(f"Unknown send instr: {instr}") 126 | elif _is_compute_instr(instr): 127 | key = _get_key(instr) 128 | if key in _prepend_map: 129 | start_instr: CommunicationStartInstruction = _prepend_map[ 130 | key 131 | ] 132 | new_instrs.append( 133 | instr_map[type(start_instr)]( 134 | start_instr.microbatch, 135 | start_instr.stage, 136 | start_instr.peer, 137 | ) 138 | ) 139 | if not _is_forward(instr): 140 | # append existing send activation finish instrs 141 | new_instrs.extend( 142 | accumulated_send_activation_finish_instrs[ 143 | (instr.microbatch, instr.stage) 144 | ].copy() 145 | ) 146 | accumulated_send_activation_finish_instrs[ 147 | (instr.microbatch, instr.stage) 148 | ].clear() 149 | new_instrs.append(instr) 150 | # append any remaining send finish instrs 151 | for ( 152 | accumulated_send_finish_instrs 153 | ) in accumulated_send_activation_finish_instrs.values(): 154 | assert len(accumulated_send_finish_instrs) == 0 155 | new_instrs.extend(accumulated_send_grad_finish_instrs) 156 | return new_instrs 157 | 158 | def _allocate_buffers(self, instrs: List[PipeInstruction]): 159 | # allcate: create new tensors (e.g. torch.zeros) 160 | # assign: assign a tensor to a buffer slot 161 | # Current assumptions: 162 | # 1. RecvXXXStart allocates its own buffers and writes to buffer_ids, 163 | # so we are only assigning buffer slots here. 
This can be optimized
164 |         #    by allocating buffers in advance if memory allocation issues
165 |         #    arise.
166 |         # 2. ForwardPass and BackwardPass read and write the same buffer_ids.
167 |         #    SendXXXStart only reads but does not write to buffer_ids.
168 |         #    RecvXXXStart creates new buffers. SendXXXFinish and RecvXXXFinish
169 |         #    do not read or write to buffer_ids.
170 |         buffer_slots: List[_Buffer] = []
171 |         key_to_buffers: Dict[Any, List[_Buffer]] = defaultdict(list)
172 | 
173 |         def _allocate_buffer_slot(
174 |             instr: BufferInstruction, shape, current_idx
175 |         ) -> _Buffer:
176 |             # find the first available buffer slot
177 |             slot = len(buffer_slots)
178 |             buffer = _Buffer(
179 |                 slot, instr.microbatch, instr.stage, shape, current_idx, None
180 |             )
181 |             buffer_slots.append(buffer)
182 |             return buffer
183 | 
184 |         for instr_idx, instr in enumerate(instrs):
185 |             if isinstance(
186 |                 instr,
187 |                 (
188 |                     ForwardPass,
189 |                     BackwardPass,
190 |                     SendActivationStart,
191 |                     SendGradStart,
192 |                 ),
193 |             ):
194 |                 key = _get_key(instr)
195 |                 if isinstance(instr, BackwardPass) and instr.first_bw_layer:
196 |                     # first backward layer directly uses forward pass buffers
197 |                     assert key not in key_to_buffers
198 |                     fw_key = (instr.microbatch, instr.stage - 1, True)
199 |                     key_to_buffers[key] = key_to_buffers[fw_key].copy()
200 |                 assert (
201 |                     key in key_to_buffers
202 |                 ), f"buffer not allocated for {instr}"
203 |                 buffers = key_to_buffers[key]
204 |                 # we only allow dropping buffers
205 |                 # allocation needs explicit instrs
206 |                 assert len(buffers) >= len(instr.buffer_shapes), (
207 |                     f"buffer allocation mismatch for {instr}, "
208 |                     f"expected at least {len(instr.buffer_shapes)}, "
209 |                     f"got {len(buffers)}"
210 |                 )
211 |                 for buffer in buffers:
212 |                     instr.buffer_ids.append(buffer.slot)
213 |                     buffer.life_end = instr_idx
214 |             elif isinstance(
215 |                 instr, (RecvActivationStart, RecvGradStart, LoadInput)
216 |             ):
217 |                 # allocate new buffers
218 |                 key = _get_key(instr)
219 |                 for shape in instr.buffer_shapes:
220 |                     buffer = _allocate_buffer_slot(instr, shape, instr_idx)
221 |                     instr.buffer_ids.append(buffer.slot)
222 |                     key_to_buffers[key].append(buffer)
223 | 
224 |         # now insert buffer free instructions
225 |         new_instrs = []
226 |         buffers_freed_at_idx = defaultdict(list)
227 |         for buffer in buffer_slots:
228 |             assert buffer.life_end is not None, f"buffer {buffer} not used. 
" 229 | buffers_freed_at_idx[buffer.life_end].append(buffer.slot) 230 | for instr_idx, instr in enumerate(instrs): 231 | new_instrs.append(instr) 232 | if instr_idx in buffers_freed_at_idx: 233 | new_instrs.append( 234 | FreeBuffer(buffer_ids=buffers_freed_at_idx[instr_idx]) 235 | ) 236 | return new_instrs, len(buffer_slots) 237 | 238 | def optimize(self): 239 | result_instrs = [] 240 | result_num_buffers = [] 241 | for instrs in self.per_worker_instructions: 242 | instrs = self._inject_comm_finish_instrs(instrs) 243 | instrs, num_buffers = self._allocate_buffers(instrs) 244 | # check all needed buffers are allocated 245 | for instr in instrs: 246 | if isinstance( 247 | instr, 248 | ( 249 | ForwardPass, 250 | BackwardPass, 251 | SendActivationStart, 252 | SendGradStart, 253 | RecvActivationStart, 254 | RecvGradStart, 255 | LoadInput, 256 | ), 257 | ): 258 | assert len(instr.buffer_ids) >= len( 259 | instr.buffer_shapes 260 | ), f"buffer allocation mismatch for {instr}, " 261 | f"expected {len(instr.buffer_shapes)}, " 262 | f"got {len(instr.buffer_ids)}" 263 | result_instrs.append(instrs) 264 | result_num_buffers.append(num_buffers) 265 | return result_instrs, result_num_buffers 266 | -------------------------------------------------------------------------------- /dynapipe/pipe/kv_redis.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import atexit 5 | import datetime 6 | import os 7 | import subprocess 8 | import time 9 | 10 | import redis 11 | 12 | REDIS_CMD = os.environ.get("DYNAPIPE_REDIS_CMD", "redis-server") 13 | KVREDIS_POLLING_INTERVAL = float( 14 | os.environ.get("DYNAPIPE_KVREDIS_POLLING_INTERVAL", "0.05") 15 | ) 16 | KVREDIS_CONNECT_TIMEOUT = float( 17 | os.environ.get("DYNAPIPE_KVREDIS_CONNECT_TIMEOUT", 30) 18 | ) 19 | 20 | 21 | class RedisKVStore(object): 22 | # a blocking redis client 23 | def __init__(self, host, port, is_master=False): 24 | self.is_master = is_master 25 | self.host = host 26 | self.port = port 27 | if self.is_master: 28 | self.server = self._run_redis_server() 29 | # wait for redis server to start 30 | t = time.time() 31 | while True: 32 | try: 33 | self.client = redis.Redis(host=host, port=port, db=0) 34 | self.client.ping() 35 | break 36 | except redis.exceptions.ConnectionError: 37 | time.sleep(KVREDIS_POLLING_INTERVAL) 38 | if time.time() - t > KVREDIS_CONNECT_TIMEOUT: 39 | raise RuntimeError( 40 | "WARNING: Cannot connect to KV Server. " 41 | "Is DYNAPIPE_KV_HOST and " 42 | "DYNAPIPE_KV_PORT set correctly?" 
43 | ) 44 | continue 45 | # register cleanup 46 | atexit.register(self.__del__) 47 | 48 | def __del__(self): 49 | if self.is_master: 50 | if self.server.poll() is not None: 51 | return 52 | self.server.send_signal(subprocess.signal.SIGINT) 53 | self.server.wait() 54 | 55 | def _run_redis_server(self): 56 | # run a redis server 57 | p = subprocess.Popen( 58 | [ 59 | REDIS_CMD, 60 | "--save", 61 | "", 62 | "--port", 63 | str(self.port), 64 | "--bind", 65 | str(self.host), 66 | ], 67 | shell=False, 68 | stdout=subprocess.DEVNULL, 69 | stderr=subprocess.STDOUT, 70 | ) 71 | return p 72 | 73 | def wait(self, keys, timeout=None): 74 | # wait for a key to be set 75 | time_start = datetime.datetime.now() 76 | if not isinstance(keys, (list, tuple)): 77 | keys = [keys] 78 | while True: 79 | if self.client.exists(*keys): 80 | break 81 | if ( 82 | timeout is not None 83 | and datetime.datetime.now() - time_start > timeout 84 | ): 85 | # match torch kvstore behavior 86 | raise RuntimeError("Timeout") 87 | time.sleep(KVREDIS_POLLING_INTERVAL) 88 | 89 | def get(self, key, wait=True): 90 | if wait: 91 | self.wait(key) 92 | return self.client.get(key) 93 | 94 | def set(self, key, value: str, logger=None): 95 | # match torch kvstore behavior 96 | value_bytes = value.encode() 97 | self.client.set(key, value_bytes) 98 | if logger: 99 | logger.debug("KVStore: set {} to {}".format(key, value)) 100 | 101 | def add(self, key, value: int): 102 | # match torch kvstore behavior 103 | return self.client.incr(key, value) 104 | 105 | def delete_key(self, key): 106 | return self.client.delete(key) 107 | -------------------------------------------------------------------------------- /dynapipe/pipe/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from typing import List 5 | 6 | from dynapipe.model import DynaPipeCluster, TransformerModelSpec 7 | 8 | from .instructions import * # noqa: F403 9 | 10 | 11 | def validate_device_assignment( 12 | model_spec: TransformerModelSpec, 13 | cluster_spec: DynaPipeCluster, 14 | device_assignment: List[int], 15 | ): 16 | """ 17 | Validate device assignment and detect device assignment type. 18 | Args: 19 | device_assignment: List of device ids for each layer. 20 | """ 21 | appeared_devices = set() 22 | for device in device_assignment: 23 | if device not in appeared_devices: 24 | # new device 25 | assert device == len(appeared_devices), ( 26 | "Devices must appear in indexed order. " 27 | "e.g. [0, 1, 2, 3] is valid, " 28 | "[0, 1, 3, 2] is not valid." 29 | ) 30 | appeared_devices.add(device) 31 | n_devices = len(appeared_devices) 32 | assert n_devices == cluster_spec.n_devices, ( 33 | "Number of devices used in device assignment " 34 | "must be equal to number of devices in cluster spec." 
35 | ) 36 | virtual_layer_to_actual_layers = [[]] 37 | virtual_layer_devices = [0] 38 | last_device = 0 39 | for device in device_assignment: 40 | if device == last_device: 41 | virtual_layer_to_actual_layers[-1].append(device) 42 | else: 43 | virtual_layer_to_actual_layers.append([device]) 44 | virtual_layer_devices.append(device) 45 | last_device = device 46 | n_actual_layers_per_virtual_layer = len(virtual_layer_to_actual_layers[0]) 47 | for virtual_layer in virtual_layer_to_actual_layers: 48 | n_encoder_layers_in_virtual_layer = len( 49 | [ 50 | layer 51 | for layer in virtual_layer 52 | if layer < model_spec.n_encoder_layers 53 | ] 54 | ) 55 | n_decoder_layers_in_virtual_layer = ( 56 | len(virtual_layer) - n_encoder_layers_in_virtual_layer 57 | ) 58 | if n_encoder_layers_in_virtual_layer > 0: 59 | assert ( 60 | len(virtual_layer) == n_encoder_layers_in_virtual_layer 61 | ), "Number of layers on each virtual layer must be the same." 62 | if n_decoder_layers_in_virtual_layer > 0: 63 | assert ( 64 | len(virtual_layer) == n_decoder_layers_in_virtual_layer 65 | ), "Number of layers on each virtual layer must be the same." 66 | if len(device_assignment) != n_actual_layers_per_virtual_layer: 67 | # only check if we are actually using pipeline parallelism 68 | assert ( 69 | model_spec.n_encoder_layers % n_actual_layers_per_virtual_layer 70 | == 0 71 | ), ( 72 | f"Number of encoder layers ({model_spec.n_encoder_layers}) " 73 | f"must be divisible by number of layers on each virtual layer " 74 | f"({n_actual_layers_per_virtual_layer})." 75 | ) 76 | assert ( 77 | model_spec.n_decoder_layers % n_actual_layers_per_virtual_layer 78 | == 0 79 | ), ( 80 | f"Number of decoder layers ({model_spec.n_decoder_layers}) " 81 | f"must be divisible by number of layers on each virtual layer " 82 | f"({n_actual_layers_per_virtual_layer})." 83 | ) 84 | # classify device assignment into linear, interleaved and other 85 | device_assignment_type = "other" 86 | if len(virtual_layer_devices) == n_devices: 87 | if virtual_layer_devices == list(range(n_devices)): 88 | device_assignment_type = "linear" 89 | else: 90 | n_chunks = len(virtual_layer_devices) // n_devices 91 | interleaved_assignment = list(range(n_devices)) * n_chunks 92 | if interleaved_assignment == virtual_layer_devices: 93 | device_assignment_type = "interleaved" 94 | if ( 95 | device_assignment_type == "interleaved" 96 | and model_spec.n_decoder_layers == 0 97 | ): 98 | # interleaved device assignment is not supported for decoder only 99 | # models 100 | raise NotImplementedError( 101 | "Interleaved device assignment is not supported " 102 | "for decoder only models." 
103 | ) 104 | valid_schedule_methods = ["wait-free-cyclic"] 105 | if device_assignment_type == "linear" and n_devices > 1: 106 | valid_schedule_methods.append("1F1B") 107 | elif device_assignment_type == "interleaved": 108 | valid_schedule_methods.append("interleaved-1F1B") 109 | n_chunks_per_device = len(virtual_layer_devices) // n_devices 110 | return ( 111 | device_assignment_type, 112 | valid_schedule_methods, 113 | n_actual_layers_per_virtual_layer, 114 | n_chunks_per_device, 115 | ) 116 | 117 | 118 | def check_deadlock(eps: List[ExecutionPlan]): 119 | # validate if the instruction sequence will result in a deadlock 120 | # filter out all communication instructions 121 | _INSTR_TYPE_MAP = { 122 | SendActivationStart: RecvActivationStart, 123 | SendGradStart: RecvGradStart, 124 | RecvActivationStart: SendActivationStart, 125 | RecvGradStart: SendGradStart, 126 | } 127 | comm_ops_per_exec = [] 128 | for ep in eps: 129 | instrs = ep.instructions 130 | comm_ops = [] 131 | for instr in instrs: 132 | if isinstance( 133 | instr, 134 | ( 135 | SendActivationStart, 136 | SendGradStart, 137 | RecvActivationStart, 138 | RecvGradStart, 139 | ), 140 | ): 141 | comm_ops.append(instr) 142 | comm_ops_per_exec.append(comm_ops) 143 | 144 | def _alert_deadlock(exec_idx, peer_idx, current_instr_idx, peer_instr_idx): 145 | additonal_info = "" 146 | if peer_idx is None: 147 | additonal_info += ( 148 | "Executor {exec_idx} " 149 | "has unfinished instruction " 150 | f"{comm_ops_per_exec[exec_idx][current_instr_idx]}. \n" 151 | ) 152 | else: 153 | if current_instr_idx < len( 154 | comm_ops_per_exec[exec_idx] 155 | ) and peer_instr_idx < len(comm_ops_per_exec[peer_idx]): 156 | instr_order_str = "\n\t".join( 157 | [str(x) for x in comm_ops_per_exec[exec_idx]] 158 | ) 159 | peer_instr_order_str = "\n\t".join( 160 | [str(x) for x in comm_ops_per_exec[peer_idx]] 161 | ) 162 | additonal_info += ( 163 | "Mismatched instructions " 164 | f"{comm_ops_per_exec[exec_idx][current_instr_idx]}\n" 165 | "\t\t\tand " 166 | f"{comm_ops_per_exec[peer_idx][peer_instr_idx]}.\n" 167 | ) 168 | else: 169 | additonal_info += ( 170 | "No matching instruction for " 171 | f"{comm_ops_per_exec[exec_idx][current_instr_idx]}. \n" 172 | ) 173 | additonal_info += ( 174 | f"Instruction order: \n\t{instr_order_str}.\n" 175 | f"Peer instruction order: \n\t{peer_instr_order_str}." 
176 | ) 177 | raise RuntimeError( 178 | "[INTERNAL ERROR] " 179 | f"Deadlock detected between exec {exec_idx} " 180 | f"(current) and {peer_idx} (peer).\n" + additonal_info 181 | ) 182 | 183 | current_instrs_per_exec = [0] * len(eps) 184 | # eliminate matching instructions 185 | while True: 186 | progress = False 187 | for exec_idx, instr_idx in enumerate(current_instrs_per_exec): 188 | if instr_idx >= len(comm_ops_per_exec[exec_idx]): 189 | continue 190 | instr = comm_ops_per_exec[exec_idx][instr_idx] 191 | # check if there is a matching instruction on peer exec 192 | peer = instr.peer 193 | peer_instr_idx = current_instrs_per_exec[peer] 194 | if peer_instr_idx >= len(comm_ops_per_exec[peer]): 195 | # deadlock 196 | _alert_deadlock(exec_idx, peer, instr_idx, peer_instr_idx) 197 | peer_instr = comm_ops_per_exec[peer][peer_instr_idx] 198 | # check peer 199 | if peer_instr.peer != exec_idx: 200 | # not current receiving from/sending to us 201 | continue 202 | # if peer is receiving/sending to us, the instructions must match 203 | # check if the instruction type matches 204 | if not isinstance(instr, _INSTR_TYPE_MAP[type(peer_instr)]): 205 | # deadlock 206 | _alert_deadlock(exec_idx, peer, instr_idx, peer_instr_idx) 207 | # check shape 208 | if instr.buffer_shapes != peer_instr.buffer_shapes: 209 | # deadlock 210 | _alert_deadlock(exec_idx, peer, instr_idx, peer_instr_idx) 211 | # check passed, increment instruction index on both execs 212 | current_instrs_per_exec[exec_idx] += 1 213 | current_instrs_per_exec[peer] += 1 214 | progress = True 215 | if not progress: 216 | break 217 | # check if all instructions are consumed 218 | for exec_idx, instr_idx in enumerate(current_instrs_per_exec): 219 | if instr_idx < len(comm_ops_per_exec[exec_idx]): 220 | _alert_deadlock(exec_idx, None, instr_idx, None) 221 | -------------------------------------------------------------------------------- /dynapipe/schedule_opt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from .cyclic_schedule import CyclicScheduler 5 | from .fifo_schedule import FIFOScheduler 6 | from .ofob_schedule import OFOBSchedulerRegistry as reg 7 | from .wait_free_cyclic_schedule import WaitFreeCyclicScheduler 8 | 9 | AVAILABLE_SCHEDULERS = { 10 | "cyclic": CyclicScheduler, 11 | "fifo": FIFOScheduler, 12 | "wait-free-cyclic": WaitFreeCyclicScheduler, 13 | "1F1B": reg.get_scheduler_factory(placement_type="linear"), 14 | "relaxed-1F1B": reg.get_scheduler_factory(strictness="relaxed"), 15 | "interleaved-1F1B": reg.get_scheduler_factory( 16 | placement_type="interleaved" 17 | ), 18 | "interleaved-relaxed-1F1B": reg.get_scheduler_factory( 19 | strictness="interleaved-relaxed" 20 | ), 21 | "interleaved-cyclic-1F1B": reg.get_scheduler_factory( 22 | placement_type="interleaved", dependency_policy="cyclic" 23 | ), 24 | } 25 | 26 | 27 | def get_available_schedulers(): 28 | return AVAILABLE_SCHEDULERS.keys() 29 | 30 | 31 | def get_scheduler_class(scheduler_name): 32 | if scheduler_name not in AVAILABLE_SCHEDULERS: 33 | raise ValueError( 34 | f"Scheduler {scheduler_name} not available." 
35 | f"Available schedulers: {get_available_schedulers()}" 36 | ) 37 | return AVAILABLE_SCHEDULERS[scheduler_name] 38 | -------------------------------------------------------------------------------- /dynapipe/schedule_opt/cyclic_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import logging 5 | from collections import defaultdict 6 | from typing import Dict, List, Optional, Tuple 7 | 8 | from .schedule_common import ( 9 | ExecutorIndex, 10 | ScheduleExecutor, 11 | ScheduleOperation, 12 | Scheduler, 13 | SchedulerMinibatchSpec, 14 | ) 15 | 16 | 17 | class CyclicScheduler(Scheduler): 18 | def __init__( 19 | self, 20 | minibatch_spec: SchedulerMinibatchSpec, 21 | include_memory_stats: bool = True, 22 | memory_limit: float = float("inf"), 23 | max_otf_microbatches: int = int(1e6), 24 | logger: Optional[logging.Logger] = None, 25 | ) -> None: 26 | super().__init__( 27 | minibatch_spec, 28 | include_memory_stats, 29 | memory_limit, 30 | logger=logger, 31 | ) 32 | self.max_otf_microbatches = max_otf_microbatches 33 | self._initialize() 34 | # calculate cycle time 35 | executor_exec_times = [] 36 | max_stage_exec_times_acorss_mb = [] 37 | for layer in range(self.n_flattened_stages): 38 | max_time = 0 39 | for microbatch in self.minibatch_spec.microbatches: 40 | max_time = max( 41 | max_time, microbatch.flattened_exec_times[layer] 42 | ) 43 | max_stage_exec_times_acorss_mb.append(max_time) 44 | for executor in self.executors.items(): 45 | executor_exec_times.append( 46 | sum( 47 | (max_stage_exec_times_acorss_mb[i]) 48 | for i in self.executor2stages[executor] 49 | ) 50 | ) 51 | self.cycle_time = max(executor_exec_times) 52 | self.fast_forward_comm = True 53 | self.executors: Dict[int, CyclicExecutor] 54 | 55 | def _get_executor( 56 | self, 57 | executor_id, 58 | thread_id, 59 | n_orig_layers, 60 | assigned_stages, 61 | is_comm_stage, 62 | include_memory_stats, 63 | memory_limit=float("inf"), 64 | parent_executor=None, 65 | ): 66 | # overrides Scheduler 67 | return CyclicExecutor( 68 | executor_id, 69 | thread_id=thread_id, 70 | n_orig_layers=n_orig_layers, 71 | assigned_stages=assigned_stages, 72 | is_comm_stage=is_comm_stage, 73 | include_memory_stats=include_memory_stats, 74 | memory_limit=memory_limit, 75 | max_otf_microbatches=self.max_otf_microbatches, 76 | parent_executor=parent_executor, 77 | logger=self.logger, 78 | ) 79 | 80 | def _init_executors(self, n_microbatches, **kwargs): 81 | for executor in self.executors.values(): 82 | executor.reset() 83 | return True 84 | 85 | def _inject_microbatches( 86 | self, microbatch_offset: int, n_microbatches: int 87 | ): 88 | for microbatch_id in range( 89 | microbatch_offset, microbatch_offset + n_microbatches 90 | ): 91 | executor = self.executors[ 92 | self.minibatch_spec.flattened_executor_assignment[0] 93 | ] 94 | op = self._get_op(0, microbatch_id) 95 | executor.add_operation(op) 96 | executor.forward_cycle() 97 | 98 | def _on_executed_ops(self, executed_ops: List[ScheduleOperation]): 99 | for op in executed_ops: 100 | if op.flattened_stage < self.n_flattened_stages - 1: 101 | next_stage = op.flattened_stage + 1 102 | next_executor = ( 103 | self.minibatch_spec.flattened_executor_assignment[ 104 | next_stage 105 | ] 106 | ) 107 | next_op = self._get_op(next_stage, op.microbatch) 108 | self.executors[next_executor].add_operation(next_op) 109 | 110 | def 
_get_global_instance_event(self, name, current_time): 111 | return { 112 | "name": name, 113 | "ph": "i", 114 | "ts": current_time, 115 | "pid": 0, 116 | "tid": 0, 117 | "s": "g", 118 | } 119 | 120 | def _schedule( 121 | self, 122 | ): 123 | n_microbatches = len(self.minibatch_spec.microbatches) 124 | status = self._init_executors(n_microbatches) 125 | if not status: 126 | return None, None 127 | self._inject_microbatches(0, n_microbatches) 128 | trace_events = self._get_trace_events() 129 | operator_execution_order: Dict[ 130 | ExecutorIndex, list[ScheduleOperation] 131 | ] = defaultdict(list) 132 | current_time = 0 133 | executor_end_times = [] 134 | # for _ in range(n_microbatches + self.n_layers): 135 | while True: 136 | has_progress = False 137 | for executor in self.executors.values(): 138 | end_time, executed_ops, events = executor.exec_cycle( 139 | current_time 140 | ) 141 | executor_end_times.append(end_time) 142 | if executed_ops: 143 | self._on_executed_ops(executed_ops) 144 | trace_events["traceEvents"].extend(events) 145 | operator_execution_order[ 146 | ExecutorIndex(executor.executor_id, executor.thread_id) 147 | ] += executed_ops 148 | has_progress = True 149 | for executor in self.executors.values(): 150 | executor.forward_cycle() 151 | if self.fast_forward_comm: 152 | # execute an extra cycle to fast forward communication 153 | for executor in self.executors.values(): 154 | if executor.is_comm_stage: 155 | end_time, executed_ops, events = executor.exec_cycle( 156 | current_time 157 | ) 158 | executor_end_times.append(end_time) 159 | if executed_ops: 160 | self._on_executed_ops(executed_ops) 161 | trace_events["traceEvents"].extend(events) 162 | operator_execution_order[ 163 | ExecutorIndex( 164 | executor.executor_id, executor.thread_id 165 | ) 166 | ] += executed_ops 167 | has_progress = True 168 | for executor in self.executors.values(): 169 | executor.forward_cycle() 170 | if not has_progress: 171 | break 172 | current_time = max(executor_end_times) 173 | trace_events["traceEvents"].append( 174 | self._get_global_instance_event("Cycle ended", current_time) 175 | ) 176 | self.makespan = max(executor_end_times) 177 | for executor in self.executors.values(): 178 | for buffer in executor.buffers.values(): 179 | if len(buffer) != 0: 180 | # unable to schedule all operations 181 | self.makespan = -1 182 | return None, None 183 | return trace_events, operator_execution_order 184 | 185 | def get_operator_order( 186 | self, 187 | ): 188 | _, operator_execution_order = self._schedule() 189 | return operator_execution_order 190 | 191 | def schedule(self): 192 | trace_events, _ = self._schedule() 193 | return trace_events 194 | 195 | 196 | class CyclicExecutor(ScheduleExecutor): 197 | def __init__( 198 | self, 199 | executor_id: int, 200 | thread_id: int, 201 | n_orig_layers: int, 202 | assigned_stages: List[Tuple[int, float, bool]], 203 | is_comm_stage: bool = False, 204 | include_memory_stats: bool = True, 205 | memory_limit: float = float("inf"), 206 | max_otf_microbatches: int = int(1e6), 207 | parent_executor: Optional[ScheduleExecutor] = None, 208 | logger: Optional[logging.Logger] = None, 209 | ) -> None: 210 | super().__init__( 211 | executor_id, 212 | thread_id, 213 | n_orig_layers, 214 | assigned_stages, 215 | is_comm_stage, 216 | include_memory_stats, 217 | parent_executor, 218 | logger, 219 | ) 220 | self.memory_limit = memory_limit 221 | self.max_otf_microbatches = max_otf_microbatches 222 | self.buffers: Dict[int, List[ScheduleOperation]] = {} 223 | 
self.next_step_buffers: Dict[int, List[ScheduleOperation]] = {} 224 | for flattened_stage_id, _, _ in assigned_stages: 225 | self.buffers[flattened_stage_id] = [] 226 | self.next_step_buffers[flattened_stage_id] = [] 227 | self.exec_order = [] 228 | for i in range(max(len(self.fw_stages), len(self.bw_stages))): 229 | if i < len(self.bw_stages): 230 | self.exec_order.append(self.bw_stages[i]) 231 | if i < len(self.fw_stages): 232 | self.exec_order.append(self.fw_stages[i]) 233 | self.executed_fw_microbatches = 0 234 | self.executed_bw_microbatches = 0 235 | 236 | def reset(self): 237 | super().reset() 238 | for key in self.buffers.keys(): 239 | self.buffers[key] = [] 240 | self.next_step_buffers[key] = [] 241 | self.executed_fw_microbatches = 0 242 | self.executed_bw_microbatches = 0 243 | 244 | def add_operation(self, op: ScheduleOperation): 245 | if op.is_forward: 246 | assert ( 247 | op.flattened_stage in self.fw_stages 248 | ), "Operation {} not in executor".format(op) 249 | self.next_step_buffers[op.flattened_stage].append(op) 250 | else: 251 | assert ( 252 | op.flattened_stage in self.bw_stages 253 | ), "Operation {} not in executor".format(op) 254 | self.next_step_buffers[op.flattened_stage].append(op) 255 | 256 | def forward_cycle(self): 257 | # append next_step_buffer to buffer 258 | for key, ops in self.next_step_buffers.items(): 259 | self.buffers[key] += ops 260 | self.next_step_buffers[key] = [] 261 | 262 | def exec_cycle(self, current_time): 263 | total_exec_time_in_cycle = 0 264 | executed_ops = [] 265 | events = [] 266 | available_ops: List[ScheduleOperation] = [] 267 | for stage_id in self.fw_stages + self.bw_stages: 268 | if len(self.buffers[stage_id]) > 0: 269 | available_ops.append(self.buffers[stage_id][0]) 270 | available_bw_ops = sorted( 271 | [op for op in available_ops if not op.is_forward], 272 | key=lambda x: (x.microbatch, x.flattened_stage), 273 | ) 274 | available_fw_ops = sorted( 275 | [op for op in available_ops if op.is_forward], 276 | key=lambda x: (x.microbatch, x.flattened_stage), 277 | ) 278 | # merge 279 | available_ops: List[ScheduleOperation] = [] 280 | for i in range(max(len(available_fw_ops), len(available_bw_ops))): 281 | if i < len(available_bw_ops): 282 | available_ops.append(available_bw_ops[i]) 283 | if i < len(available_fw_ops): 284 | available_ops.append(available_fw_ops[i]) 285 | for op in available_ops: 286 | # test if executing this op will exceed memory limit 287 | if ( 288 | not self.is_comm_stage 289 | and op.is_forward 290 | and op.microbatch > self.executed_fw_microbatches 291 | and ( 292 | self.current_memory + op.peak_memory > self.memory_limit 293 | or self.executed_fw_microbatches 294 | - self.executed_bw_microbatches 295 | >= self.max_otf_microbatches 296 | ) 297 | ): 298 | # skip this op 299 | continue 300 | event = self.get_exec_event( 301 | op, 302 | current_time, 303 | op.exec_time, 304 | ) 305 | events.append(event) 306 | if not self.is_comm_stage: 307 | # we add two memory events for each op 308 | # one for the peak memory usage during the op 309 | # one for the stored memory usage after the op 310 | peak_time = current_time + op.exec_time / 2 311 | finish_time = current_time + op.exec_time 312 | memory_events = self.update_memory( 313 | peak_time, op.peak_memory, finish_time, op.stored_memory 314 | ) 315 | if self.include_memory_stats: 316 | events += memory_events 317 | executed_ops.append(op) 318 | total_exec_time_in_cycle += op.exec_time 319 | current_time += op.exec_time 320 | if op.is_forward: 321 | self.fw_count += 1 
322 | if op.microbatch > self.executed_fw_microbatches: 323 | self.executed_fw_microbatches = op.microbatch 324 | else: 325 | self.bw_count += 1 326 | if op.microbatch > self.executed_bw_microbatches: 327 | self.executed_bw_microbatches = op.microbatch 328 | self.buffers[op.flattened_stage].pop(0) 329 | return current_time, executed_ops, events 330 | -------------------------------------------------------------------------------- /dynapipe/schedule_opt/fifo_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import logging 5 | from typing import List, Optional, Tuple 6 | 7 | from dynapipe.schedule_opt.wait_free_schedule import WaitFreeScheduler 8 | 9 | from .schedule_common import ScheduleExecutor, ScheduleOperation 10 | from .wait_free_schedule import WaitFreeExecutor 11 | 12 | 13 | class FIFOExecutor(WaitFreeExecutor): 14 | def __init__( 15 | self, 16 | executor_id: int, 17 | thread_id: int, 18 | n_orig_layers: int, 19 | assigned_stages: List[Tuple[int, float, bool]], 20 | is_comm_stage: bool = False, 21 | include_memory_stats: bool = True, 22 | parent_executor: Optional[ScheduleExecutor] = None, 23 | logger: Optional[logging.Logger] = None, 24 | ) -> None: 25 | super().__init__( 26 | executor_id, 27 | thread_id, 28 | n_orig_layers, 29 | assigned_stages, 30 | is_comm_stage, 31 | include_memory_stats, 32 | parent_executor, 33 | logger, 34 | ) 35 | self.is_executing = False 36 | self.last_executed_fw = True 37 | 38 | def reset(self): 39 | super().reset() 40 | self.available_queue = [] 41 | 42 | def add_operation(self, op: ScheduleOperation): 43 | if op.is_forward: 44 | assert ( 45 | op.flattened_stage in self.fw_stages 46 | ), "Operation {} not in executor".format(op) 47 | else: 48 | assert ( 49 | op.flattened_stage in self.bw_stages 50 | ), "Operation {} not in executor".format(op) 51 | self.available_queue.append(op) 52 | 53 | def try_execute(self, current_time, ofob=False): 54 | events = [] 55 | if self.available_queue and not self.is_executing: 56 | if ofob: 57 | op = None 58 | for idx, avail_op in enumerate(self.available_queue): 59 | if avail_op.is_forward != self.last_executed_fw: 60 | op = self.available_queue.pop(idx) 61 | break 62 | if op is None: 63 | return current_time, None, events 64 | else: 65 | op = self.available_queue.pop(0) 66 | event = self.get_exec_event( 67 | op, 68 | current_time, 69 | op.exec_time, 70 | ) 71 | events.append(event) 72 | if not self.is_comm_stage and self.include_memory_stats: 73 | # we add two memory events for each op 74 | # one for the peak memory usage during the op 75 | # one for the stored memory usage after the op 76 | peak_time = current_time + op.exec_time / 2 77 | finish_time = current_time + op.exec_time 78 | memory_events = self.update_memory( 79 | peak_time, op.peak_memory, finish_time, op.stored_memory 80 | ) 81 | events += memory_events 82 | finish_time = current_time + op.exec_time 83 | self.is_executing = True 84 | self.last_executed_fw = op.is_forward 85 | return finish_time, op, events 86 | else: 87 | return current_time, None, events 88 | 89 | def finish_execute(self): 90 | self.is_executing = False 91 | 92 | 93 | class FIFOScheduler(WaitFreeScheduler): 94 | def _get_executor( 95 | self, 96 | executor_id, 97 | thread_id, 98 | n_orig_layers, 99 | assigned_stages, 100 | is_comm_stage, 101 | include_memory_stats, 102 | memory_limit=float("inf"), 103 | 
parent_executor=None, 104 | ): 105 | # overrides Scheduler 106 | return FIFOExecutor( 107 | executor_id, 108 | thread_id=thread_id, 109 | n_orig_layers=n_orig_layers, 110 | assigned_stages=assigned_stages, 111 | is_comm_stage=is_comm_stage, 112 | include_memory_stats=include_memory_stats, 113 | parent_executor=parent_executor, 114 | logger=self.logger, 115 | ) 116 | -------------------------------------------------------------------------------- /dynapipe/schedule_opt/wait_free_cyclic_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import dataclasses 5 | import logging 6 | from typing import Dict, List, Optional, Tuple 7 | 8 | from .cyclic_schedule import CyclicScheduler 9 | from .schedule_common import ( 10 | ExecutorIndex, 11 | ScheduleExecutor, 12 | ScheduleOperation, 13 | SchedulerMinibatchSpec, 14 | ) 15 | from .wait_free_schedule import WaitFreeExecutor, WaitFreeScheduler 16 | 17 | 18 | class WaitFreeCyclicExecutor(WaitFreeExecutor): 19 | def __init__( 20 | self, 21 | executor_id: int, 22 | thread_id: int, 23 | n_orig_layers: int, 24 | assigned_stages: List[Tuple[int, float, bool]], 25 | is_comm_stage: bool = False, 26 | include_memory_stats: bool = True, 27 | parent_executor: Optional[ScheduleExecutor] = None, 28 | logger: Optional[logging.Logger] = None, 29 | ) -> None: 30 | super().__init__( 31 | executor_id, 32 | thread_id, 33 | n_orig_layers, 34 | assigned_stages, 35 | is_comm_stage, 36 | include_memory_stats, 37 | parent_executor, 38 | logger, 39 | ) 40 | self.available_ops = set() 41 | self.next_op_idx = 0 42 | self.is_executing = False 43 | self.operator_order = None 44 | 45 | def set_operator_order(self, operator_order: List[ScheduleOperation]): 46 | self.operator_order = operator_order 47 | self.debug_print("Operator order: {}".format(operator_order)) 48 | 49 | def reset(self): 50 | super().reset() 51 | self.available_ops.clear() 52 | self.next_op_idx = 0 53 | 54 | def add_operation(self, op: ScheduleOperation): 55 | if op.is_forward: 56 | assert ( 57 | op.flattened_stage in self.fw_stages 58 | ), "Operation {} not in executor".format(op) 59 | else: 60 | assert ( 61 | op.flattened_stage in self.bw_stages 62 | ), "Operation {} not in executor".format(op) 63 | self.available_ops.add(op) 64 | 65 | def try_execute(self, current_time): 66 | assert self.operator_order is not None, "Execution order not set" 67 | events = [] 68 | if not self.is_executing and self.next_op_idx < len( 69 | self.operator_order 70 | ): 71 | self.debug_print( 72 | "Trying to execute next operation: {}".format( 73 | self.operator_order[self.next_op_idx] 74 | ) 75 | ) 76 | next_op = self.operator_order[self.next_op_idx] 77 | if next_op in self.available_ops: 78 | self.next_op_idx += 1 79 | event = self.get_exec_event( 80 | next_op, 81 | current_time, 82 | next_op.exec_time, 83 | ) 84 | events.append(event) 85 | if not self.is_comm_stage and self.include_memory_stats: 86 | # we add two memory events for each op 87 | # one for the peak memory usage during the op 88 | # one for the stored memory usage after the op 89 | peak_time = current_time + next_op.exec_time / 2 90 | finish_time = current_time + next_op.exec_time 91 | memory_events = self.update_memory( 92 | peak_time, 93 | next_op.peak_memory, 94 | finish_time, 95 | next_op.stored_memory, 96 | ) 97 | events += memory_events 98 | finish_time = current_time + next_op.exec_time 99 | 
self.is_executing = True 100 | return finish_time, next_op, events 101 | # Currently executing or no available operations 102 | return current_time, None, events 103 | 104 | def finish_execute(self): 105 | self.is_executing = False 106 | 107 | 108 | class WaitFreeCyclicScheduler(WaitFreeScheduler): 109 | def __init__( 110 | self, 111 | minibatch_spec: SchedulerMinibatchSpec, 112 | include_memory_stats: bool = True, 113 | memory_limit: float = float("inf"), 114 | max_otf_microbatches: int = int(1e6), 115 | logger: Optional[logging.Logger] = None, 116 | ): 117 | super().__init__( 118 | minibatch_spec, 119 | include_memory_stats, 120 | memory_limit, 121 | logger=logger, 122 | ) 123 | self.cyclic_scheduler = CyclicScheduler( 124 | minibatch_spec, 125 | self.include_memory_stats, 126 | memory_limit=memory_limit, 127 | max_otf_microbatches=max_otf_microbatches, 128 | logger=logger, 129 | ) 130 | self.no_valid_schedule = False 131 | self.executors: Dict[ExecutorIndex, WaitFreeCyclicExecutor] 132 | 133 | def _get_executor( 134 | self, 135 | executor_id, 136 | thread_id, 137 | n_orig_layers, 138 | assigned_stages, 139 | is_comm_stage, 140 | include_memory_stats, 141 | memory_limit=float("inf"), 142 | parent_executor=None, 143 | ): 144 | # overrides Scheduler 145 | return WaitFreeCyclicExecutor( 146 | executor_id, 147 | thread_id=thread_id, 148 | n_orig_layers=n_orig_layers, 149 | assigned_stages=assigned_stages, 150 | is_comm_stage=is_comm_stage, 151 | include_memory_stats=include_memory_stats, 152 | parent_executor=parent_executor, 153 | logger=self.logger, 154 | ) 155 | 156 | def _init_executors(self, n_microbatches, **kwargs): 157 | self.operator_order = self.cyclic_scheduler.get_operator_order() 158 | if self.operator_order is None: 159 | # no valid schedule found 160 | self.no_valid_schedule = True 161 | return False 162 | # overrides WaitFreeScheduler 163 | status = super()._init_executors( 164 | n_microbatches, 165 | **kwargs, 166 | ) 167 | if not status: 168 | return False 169 | for executor_idx, executor in self.executors.items(): 170 | ops = self.operator_order[executor_idx] 171 | 172 | def _create_new_op(op: ScheduleOperation): 173 | # we need to reset the op's executor. 
since op is immutable, 174 | # we need to create a new op 175 | if op.next_executor is None: 176 | next_executor = None 177 | else: 178 | next_executor_id = ExecutorIndex( 179 | op.next_executor.executor_id, 180 | op.next_executor.thread_id, 181 | ) 182 | next_executor = self.executors[next_executor_id] 183 | return dataclasses.replace(op, next_executor=next_executor) 184 | 185 | executor.set_operator_order([_create_new_op(op) for op in ops]) 186 | return True 187 | 188 | def schedule( 189 | self, warmup=False, warmup_n_microbatches=-1, ofob=False, **kwargs 190 | ): 191 | # overrides WaitFreeScheduler 192 | n_microbatches = len(self.minibatch_spec.microbatches) 193 | if warmup: 194 | if warmup_n_microbatches == -1: 195 | warmup_n_microbatches = min( 196 | self.n_flattened_stages - 1, n_microbatches 197 | ) 198 | self.logger.warning( 199 | "warmup_n_microbatches <= 0, " 200 | "setting it to min(n_layers - 1, n_microbatches) " 201 | f"({warmup_n_microbatches})" 202 | ) 203 | else: 204 | assert ( 205 | warmup_n_microbatches <= n_microbatches 206 | ), "warmup_n_microbatches must be <= n_microbatches" 207 | return super().schedule( 208 | warmup=warmup, 209 | warmup_n_microbatches=warmup_n_microbatches, 210 | ofob=ofob, 211 | **kwargs, 212 | ) 213 | -------------------------------------------------------------------------------- /dynapipe/schedule_opt/wait_free_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import logging 5 | from dataclasses import dataclass, field 6 | from queue import PriorityQueue 7 | from typing import Dict, List, Optional 8 | 9 | from .schedule_common import ( 10 | ExecutorIndex, 11 | ScheduleExecutor, 12 | ScheduleOperation, 13 | Scheduler, 14 | SchedulerMinibatchSpec, 15 | ) 16 | 17 | 18 | @dataclass(order=True) 19 | class CompleteEvent: 20 | completion_time: float 21 | op: ScheduleOperation = field(compare=False) 22 | executor: ScheduleExecutor = field(compare=False) 23 | 24 | 25 | class WaitFreeExecutor(ScheduleExecutor): 26 | def __init__(self, *args, **kwargs): 27 | super().__init__(*args, **kwargs) 28 | self.available_queue: List[ScheduleOperation] = [] 29 | 30 | def add_operation(self, op: ScheduleOperation): 31 | raise NotImplementedError 32 | 33 | def try_execute(self, current_time): 34 | raise NotImplementedError 35 | 36 | def finish_execute(self): 37 | raise NotImplementedError 38 | 39 | 40 | class WaitFreeScheduler(Scheduler): 41 | def __init__( 42 | self, 43 | minibatch_spec: SchedulerMinibatchSpec, 44 | include_memory_stats: bool = True, 45 | memory_limit: float = float("inf"), 46 | logger: Optional[logging.Logger] = None, 47 | ) -> None: 48 | super().__init__( 49 | minibatch_spec, 50 | include_memory_stats, 51 | memory_limit, 52 | logger=logger, 53 | ) 54 | self._initialize() 55 | self._pending_events: PriorityQueue[CompleteEvent] = PriorityQueue() 56 | self.executors: Dict[ExecutorIndex, WaitFreeExecutor] 57 | 58 | def _get_executor( 59 | self, 60 | executor_id, 61 | thread_id, 62 | n_orig_layers, 63 | assigned_stages, 64 | is_comm_stage, 65 | include_memory_stats, 66 | memory_limit=float("inf"), 67 | ): 68 | # overrides Scheduler 69 | raise NotImplementedError 70 | 71 | def _init_executors(self, n_microbatches, **kwargs): 72 | for executor in self.executors.values(): 73 | executor.reset() 74 | self.communication_executors = [ 75 | executor 76 | for executor in self.executors.values() 77 | 
if executor.is_comm_stage 78 | ] 79 | self.computation_executors = [ 80 | executor 81 | for executor in self.executors.values() 82 | if not executor.is_comm_stage 83 | ] 84 | return True 85 | 86 | def _inject_microbatches( 87 | self, microbatch_offset: int, n_microbatches: int 88 | ): 89 | for microbatch_id in range( 90 | microbatch_offset, microbatch_offset + n_microbatches 91 | ): 92 | executor = self.executors[ 93 | self.minibatch_spec.flattened_executor_assignment[0] 94 | ] 95 | op = self._get_op(0, microbatch_id) 96 | executor.add_operation(op) 97 | 98 | def _on_op_finish(self, executor: WaitFreeExecutor, op: ScheduleOperation): 99 | executor.finish_execute() 100 | if op.flattened_stage < self.n_flattened_stages - 1: 101 | next_layer = op.flattened_stage + 1 102 | next_executor = self.minibatch_spec.flattened_executor_assignment[ 103 | next_layer 104 | ] 105 | self.executors[next_executor].add_operation( 106 | self._get_op(next_layer, op.microbatch) 107 | ) 108 | 109 | def _push_end_event(self, op, executor, end_time): 110 | self._pending_events.put(CompleteEvent(end_time, op, executor)) 111 | 112 | def schedule(self, **kwargs): 113 | n_microbatches = len(self.minibatch_spec.microbatches) 114 | status = self._init_executors(n_microbatches, **kwargs) 115 | if not status: 116 | return None 117 | self._inject_microbatches(0, n_microbatches) 118 | trace_events = self._get_trace_events() 119 | current_time = 0 120 | 121 | def __try_execute(): 122 | # prioritize communication executors 123 | for executor in ( 124 | self.communication_executors + self.computation_executors 125 | ): 126 | end_time, launched_op, events = executor.try_execute( 127 | current_time 128 | ) 129 | if launched_op: 130 | self._push_end_event(launched_op, executor, end_time) 131 | trace_events["traceEvents"].extend(events) 132 | 133 | while True: 134 | __try_execute() 135 | if self._pending_events.empty(): 136 | break 137 | else: 138 | next_event = self._pending_events.get() 139 | current_time = next_event.completion_time 140 | ready_events = [next_event] 141 | while not self._pending_events.empty(): 142 | # try to process all events that finish at the same time 143 | another_event = self._pending_events.get() 144 | if another_event.completion_time <= current_time + 1e-6: 145 | ready_events.append(another_event) 146 | else: 147 | self._pending_events.put(another_event) 148 | break 149 | for event in ready_events: 150 | self._on_op_finish(event.executor, event.op) 151 | self.makespan = current_time 152 | # make sure all executors are empty 153 | for executor_idx, executor in self.executors.items(): 154 | if hasattr(executor, "available_queue"): 155 | assert len(executor.available_queue) == 0, ( 156 | f"Executor {executor_idx} has non-empty ready queue " 157 | f"at end of scheduling: {executor.available_queue}" 158 | ) 159 | if hasattr(executor, "next_op_idx"): 160 | assert executor.next_op_idx == len(executor.operator_order), ( 161 | f"Executor {executor_idx} has not finished all operations " 162 | f"at end of scheduling: {executor.operator_order}" 163 | ) 164 | return trace_events 165 | -------------------------------------------------------------------------------- /dynapipe/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import logging 5 | import os 6 | import sys 7 | from logging import Handler 8 | from typing import List 9 | 10 | _logging_lvl_map = { 11 | "DEBUG": logging.DEBUG, 12 | "INFO": logging.INFO, 13 | "WARNING": logging.WARNING, 14 | "ERROR": logging.ERROR, 15 | "CRITICAL": logging.CRITICAL, 16 | } 17 | _logging_lvl_from_env = os.environ.get("DYNAPIPE_DEBUG", "INFO") 18 | if _logging_lvl_from_env not in _logging_lvl_map: 19 | raise ValueError( 20 | f"Invalid logging level from env detected: {_logging_lvl_from_env}. " 21 | f"Valid options are {list(_logging_lvl_map.keys())}" 22 | ) 23 | _default_logging_level = _logging_lvl_map[_logging_lvl_from_env] 24 | 25 | if _default_logging_level == logging.DEBUG: 26 | debug_dir = os.environ.get("DYNAPIPE_LOGGING_DEBUG_DIR", None) 27 | if debug_dir is None: 28 | raise ValueError( 29 | "DYNAPIPE_LOGGING_DEBUG_DIR must be set when " 30 | "DYNAPIPE_DEBUG is set to DEBUG" 31 | ) 32 | # create output dir for executor logs 33 | os.makedirs(debug_dir, exist_ok=True) 34 | _debug_log_dir = debug_dir 35 | # create subdirs 36 | _dataloader_log_dir = os.path.join(_debug_log_dir, "dataloader") 37 | os.makedirs(_dataloader_log_dir, exist_ok=True) 38 | _preprocessing_log_dir = os.path.join(_debug_log_dir, "preprocessing") 39 | os.makedirs(_preprocessing_log_dir, exist_ok=True) 40 | _executor_log_dir = os.path.join(_debug_log_dir, "executor") 41 | os.makedirs(_executor_log_dir, exist_ok=True) 42 | _poller_log_dir = os.path.join(_debug_log_dir, "poller") 43 | os.makedirs(_poller_log_dir, exist_ok=True) 44 | else: 45 | # not used 46 | _debug_log_dir = "./" 47 | 48 | 49 | # modified from https://stackoverflow.com/a/56944256 50 | class DynaPipeFormatter(logging.Formatter): 51 | white = "\x1b[38;20m" 52 | grey = "\x1b[38;5;8m" 53 | yellow = "\x1b[33;20m" 54 | red = "\x1b[31;20m" 55 | bold_red = "\x1b[31;1m" 56 | reset = "\x1b[0m" 57 | 58 | color_mapping = { 59 | logging.DEBUG: grey, 60 | logging.INFO: white, 61 | logging.WARNING: yellow, 62 | logging.ERROR: red, 63 | logging.CRITICAL: bold_red, 64 | } 65 | 66 | def __init__(self, prefix=None, distributed_rank=None, colored=True): 67 | self.prefix = prefix 68 | self.distributed_rank = distributed_rank 69 | self.colored = colored 70 | 71 | def _get_fmt_colored(self, level): 72 | color = self.color_mapping[level] 73 | fmt = ( 74 | self.grey 75 | + "[%(asctime)s] " 76 | + self.reset 77 | + color 78 | + "[%(levelname)s] " 79 | + self.reset 80 | + self.grey 81 | + "[%(filename)s:%(lineno)d] " 82 | + self.reset 83 | ) 84 | fmt += color 85 | if self.prefix is not None: 86 | fmt += "[" + self.prefix + "] " 87 | if self.distributed_rank is not None: 88 | fmt += "[Rank " + str(self.distributed_rank) + "] " 89 | fmt += "%(message)s" 90 | fmt += self.reset 91 | return fmt 92 | 93 | def _get_fmt(self): 94 | fmt = "[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d] " 95 | if self.prefix is not None: 96 | fmt += "[" + self.prefix + "] " 97 | if self.distributed_rank is not None: 98 | fmt += "[Rank " + str(self.distributed_rank) + "] " 99 | fmt += "%(message)s" 100 | return fmt 101 | 102 | def format(self, record): 103 | if self.colored: 104 | log_fmt = self._get_fmt_colored(record.levelno) 105 | else: 106 | log_fmt = self._get_fmt() 107 | formatter = logging.Formatter(log_fmt) 108 | return formatter.format(record) 109 | 110 | 111 | # modified from https://stackoverflow.com/a/51612402 112 | class LoggerWriter(object): 113 | def __init__(self, writers): 114 | if not isinstance(writers, 
(list, tuple)): 115 | writers = (writers,) 116 | self._writers = writers 117 | self._msg: str = "" 118 | 119 | def write(self, message: str): 120 | self._msg = self._msg + message 121 | pos = self._msg.find("\n") 122 | while pos != -1: 123 | for writer in self._writers: 124 | writer(self._msg[:pos]) 125 | self._msg = self._msg[pos + 1 :] 126 | pos = self._msg.find("\n") 127 | 128 | def flush(self): 129 | if self._msg != "": 130 | for writer in self._writers: 131 | writer(self._msg) 132 | self._msg = "" 133 | 134 | 135 | # modified from DeepSpeed 136 | # deepspeed/utils/logging.py 137 | def create_logger( 138 | name=None, 139 | prefix=None, 140 | level=_default_logging_level, 141 | distributed_rank=None, 142 | log_file=None, 143 | ): 144 | """Create a logger. 145 | Args: 146 | name (str): name of the logger 147 | level: level of the logger 148 | Raises: 149 | ValueError if name is None 150 | """ 151 | 152 | if name is None: 153 | raise ValueError("name for logger cannot be None") 154 | 155 | if level > logging.DEBUG: 156 | # disable log to file 157 | log_file = None 158 | 159 | formatter = DynaPipeFormatter( 160 | prefix=prefix, distributed_rank=distributed_rank, colored=False 161 | ) 162 | colored_formatter = DynaPipeFormatter( 163 | prefix=prefix, distributed_rank=distributed_rank, colored=True 164 | ) 165 | 166 | logger_ = logging.getLogger(name) 167 | # if handler already present, remove it 168 | if logger_.hasHandlers(): 169 | logger_.handlers.clear() 170 | logger_.setLevel(level) 171 | logger_.propagate = False 172 | handlers: List[Handler] = [] 173 | if log_file is not None: 174 | full_log_path = os.path.join(_debug_log_dir, log_file) 175 | handler = logging.FileHandler(filename=full_log_path) 176 | handler.setFormatter(formatter) 177 | handler.setLevel(level) 178 | handlers.append(handler) 179 | # also log all warnings/errors to stderr 180 | warn_handler = logging.StreamHandler(stream=sys.__stderr__) 181 | warn_handler.setLevel(logging.WARNING) 182 | warn_handler.setFormatter(colored_formatter) 183 | handlers.append(warn_handler) 184 | # copy stderr to log file 185 | sys.stderr = LoggerWriter([logger_.warning]) 186 | else: 187 | ch = logging.StreamHandler(stream=sys.stdout) 188 | ch.setLevel(level) 189 | ch.setFormatter(colored_formatter) 190 | handlers.append(ch) 191 | for handler in handlers: 192 | logger_.addHandler(handler) 193 | return logger_ 194 | 195 | 196 | logger = create_logger(name="DynaPipe", level=_default_logging_level) 197 | -------------------------------------------------------------------------------- /dynapipe/utils/memory_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from dataclasses import dataclass 5 | from typing import Tuple, Union 6 | 7 | from dynapipe.model import TransformerModelSpec 8 | 9 | MAX_POSSIBLE_MICROBATCH_SIZE = 2**31 - 1 10 | 11 | 12 | # utility functions to calculate transformer memory consumption 13 | def get_transformer_output_memory( 14 | sequence_length, batch_size, hidden_dim, bytes_per_element 15 | ): 16 | # size is in MB (megabytes) 17 | return sequence_length * batch_size * hidden_dim * bytes_per_element / 1e6 18 | 19 | 20 | def get_transformer_activation( 21 | sequence_length, 22 | batch_size, 23 | hidden_dim, 24 | num_attn_heads, 25 | mlp_hidden_dim, 26 | bytes_per_element, 27 | tp_size, 28 | is_decoder=False, 29 | ): 30 | # Estimates the activation memory needed for a transformer layer. 31 | # Size is in MB (megabytes) 32 | # Formula from Korthikanti et.al, 33 | # "Reducing Activation Recomputation in Large Transformer Models" 34 | # https://arxiv.org/abs/2205.05198 35 | result = 0 36 | sbh = sequence_length * batch_size * hidden_dim 37 | sbh_mlp = sequence_length * batch_size * mlp_hidden_dim 38 | as2b = num_attn_heads * sequence_length * sequence_length * batch_size 39 | attention = 0 40 | # QKV 41 | attention_input = sbh * bytes_per_element 42 | # QK^T 43 | attention += 2 * sbh * bytes_per_element 44 | # softmax 45 | attention += as2b * bytes_per_element 46 | # softmax dropout (one byte per element) 47 | attention += as2b 48 | # attn over values 49 | attention += (as2b + sbh) * bytes_per_element 50 | attention /= tp_size 51 | # attention input is not parallelised in tensor parallelism 52 | attention += attention_input 53 | 54 | result += attention 55 | # MLP (assume MLP hidden dim is same as hidden dim) 56 | # MLP input is not parallelised in tensor parallelism, 57 | # i.e., sbh * bytes_per_element 58 | # other parts are parallelised 59 | result += ( 60 | sbh * bytes_per_element 61 | + (2 * sbh_mlp * bytes_per_element + sbh) / tp_size 62 | ) 63 | # layernorm 64 | result += 2 * sbh * bytes_per_element 65 | if is_decoder: 66 | # cross-attention 67 | result += attention 68 | # encoder input 69 | result += sbh * bytes_per_element 70 | return result / 1e6 71 | 72 | 73 | def get_transformer_model_state( 74 | hidden_dim, 75 | num_attn_heads, 76 | kv_channels, 77 | mlp_hidden_dim, 78 | bytes_per_element, 79 | optimizer_state_multiplier, 80 | tp_size, 81 | is_decoder=False, 82 | ): 83 | # Estimate transformer model state usage, assuming Adam optimizer 84 | # and fp16 training is used. Size is in MB (megabytes) 85 | # Note: optimizer state multiplier should already consider 86 | # the number bytes needed to store each element. 87 | # bytes_per_element is only used to calculate the size of 88 | # the model parameters and gradients. 89 | # Optimizer state multiplier is 12 for FP16 mixed precision Adam. 
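# (The 12x multiplier follows the ZeRO reference below: mixed-precision
#  Adam keeps a fp32 master copy of the parameters (4 bytes) plus fp32
#  momentum (4 bytes) and variance (4 bytes) per parameter. With fp16
#  training, i.e. bytes_per_element = 2, each parameter therefore
#  contributes 2 (param) + 2 (grad) + 12 (optimizer state) = 16 bytes of
#  model state, matching the formula at the end of this function.)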
90 | # Reference: Rajbhandari et.al., 91 | # "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models" 92 | # https://arxiv.org/abs/1910.02054 93 | n_params = 0 94 | attention = 0 95 | # layer norm x2 96 | n_params += 2 * 2 * hidden_dim 97 | # QKV 98 | attention += 3 * (hidden_dim * kv_channels + kv_channels) 99 | # projection 100 | attention += num_attn_heads * kv_channels * hidden_dim + hidden_dim 101 | attention /= tp_size 102 | n_params += attention 103 | # MLP 104 | n_params += ( 105 | 2 * hidden_dim * mlp_hidden_dim + hidden_dim + mlp_hidden_dim 106 | ) / tp_size 107 | if is_decoder: 108 | # cross-attention 109 | n_params += attention 110 | # layer norm 111 | n_params += 2 * hidden_dim 112 | # scale to model state 113 | result = n_params * ( 114 | bytes_per_element + bytes_per_element + optimizer_state_multiplier 115 | ) 116 | return result / 1e6 117 | 118 | 119 | @dataclass 120 | class TransformerMemoryModel(object): 121 | # See comments for get_transformer_activation and 122 | # get_transformer_model_state for details 123 | model_spec: TransformerModelSpec 124 | 125 | def get_output_memory(self, batch_size, sequence_length): 126 | return get_transformer_output_memory( 127 | sequence_length, 128 | batch_size, 129 | self.model_spec.hidden_dim, 130 | bytes_per_element=self.model_spec.bytes_per_element, 131 | ) 132 | 133 | def get_activation_memory( 134 | self, batch_size, sequence_length, is_decoder=False 135 | ): 136 | return get_transformer_activation( 137 | sequence_length, 138 | batch_size, 139 | self.model_spec.hidden_dim, 140 | self.model_spec.num_attn_heads, 141 | mlp_hidden_dim=self.model_spec.mlp_hidden_dim, 142 | bytes_per_element=self.model_spec.bytes_per_element, 143 | tp_size=self.model_spec.tp_size, 144 | is_decoder=is_decoder, 145 | ) 146 | 147 | def get_model_state_memory(self, is_decoder=False): 148 | return get_transformer_model_state( 149 | self.model_spec.hidden_dim, 150 | self.model_spec.num_attn_heads, 151 | kv_channels=self.model_spec.kv_channels, 152 | mlp_hidden_dim=self.model_spec.mlp_hidden_dim, 153 | bytes_per_element=self.model_spec.bytes_per_element, 154 | optimizer_state_multiplier=self.model_spec.optimizer_state_multiplier, # noqa 155 | tp_size=self.model_spec.tp_size, 156 | is_decoder=is_decoder, 157 | ) 158 | 159 | 160 | @dataclass 161 | class InvTransformerMemoryModel: 162 | n_encoders: int 163 | n_decoders: int 164 | model_spec: TransformerModelSpec 165 | _mem_model: Union[TransformerMemoryModel, None] = None 166 | 167 | def __post_init__(self): 168 | if self._mem_model is None: 169 | self._mem_model = TransformerMemoryModel( 170 | model_spec=self.model_spec, 171 | ) 172 | 173 | def _get_memory(self, mbs: int, seq_len: int) -> float: 174 | return self.n_encoders * self._mem_model.get_activation_memory( 175 | mbs, seq_len 176 | ) + self.n_decoders * self._mem_model.get_activation_memory( 177 | mbs, seq_len, is_decoder=True 178 | ) 179 | 180 | def set_max_memory(self, max_memory: float) -> None: 181 | self._ref_memory = max_memory 182 | 183 | def set_reference(self, mbs, seq_len) -> None: 184 | self._ref_memory = self._get_memory(mbs, seq_len) 185 | 186 | def _get_mbs_within_range( 187 | self, seq_len: int, mbs_range: Tuple[int, int] 188 | ) -> int: 189 | if mbs_range[1] >= mbs_range[0]: 190 | midpoint = (mbs_range[0] + mbs_range[1]) // 2 191 | mid_memory = self._get_memory(midpoint, seq_len) 192 | mid_plus_one_memory = self._get_memory(midpoint + 1, seq_len) 193 | if ( 194 | mid_memory <= self._ref_memory 195 | and mid_plus_one_memory >= 
self._ref_memory 196 | ): 197 | return midpoint 198 | elif mid_memory <= self._ref_memory: 199 | return self._get_mbs_within_range( 200 | seq_len, (midpoint + 1, mbs_range[1]) 201 | ) 202 | else: 203 | return self._get_mbs_within_range( 204 | seq_len, (mbs_range[0], midpoint - 1) 205 | ) 206 | else: 207 | return -1 208 | 209 | def get_microbatch_size(self, sequence_length): 210 | assert hasattr( 211 | self, "_ref_memory" 212 | ), "Must set memory reference or max memory first." 213 | return self._get_mbs_within_range( 214 | sequence_length, (1, MAX_POSSIBLE_MICROBATCH_SIZE) 215 | ) 216 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | elkai @ git+https://github.com/chenyu-jiang/elkai.git@blkh 2 | numpy 3 | prtpy 4 | pybind11 5 | pytest 6 | redis 7 | scikit-learn>=1.2.0 8 | scipy 9 | setuptools 10 | sortedcontainers 11 | tqdm 12 | -------------------------------------------------------------------------------- /scripts/check_dataloader_logs.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | import os 6 | from collections import defaultdict 7 | 8 | # Parse the last generated EP from the planner logs for debugging purposes 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("log_dir", type=str, help="Path to the log directory") 14 | return parser.parse_args() 15 | 16 | 17 | def get_last_generated_ep(planner_logs): 18 | all_generated_eps = [] 19 | last_ep_per_worker = {} 20 | for worker_id, lines in planner_logs.items(): 21 | for line in lines: 22 | if "Pushing EP" in line: 23 | iteration = int(line.split()[-4]) 24 | all_generated_eps.append(iteration) 25 | last_ep_per_worker[worker_id] = iteration 26 | max_iter = max(all_generated_eps) 27 | assert sorted(all_generated_eps) == list(range(max_iter + 1)) 28 | return max_iter, last_ep_per_worker 29 | 30 | 31 | def get_last_received_ep(worker_logs): 32 | all_received_eps = [] 33 | last_ep_per_worker = {} 34 | for worker_id, lines in worker_logs.items(): 35 | for line in lines: 36 | if "Got data for" in line: 37 | iteration = int(line.split()[-1][:-1]) 38 | all_received_eps.append(iteration) 39 | last_ep_per_worker[worker_id] = iteration 40 | max_iter = max(all_received_eps) 41 | assert sorted(all_received_eps) == list(range(max_iter + 1)) 42 | return max_iter, last_ep_per_worker 43 | 44 | 45 | def parse_rank_from_log_path(log_path): 46 | log_path = os.path.basename(log_path).split(".")[0] 47 | rank = int(log_path.split("_")[0][1:]) 48 | virtual_rank = int(log_path.split("_")[1][2:]) 49 | worker_id = int(log_path.split("_")[2][1:]) 50 | return rank, virtual_rank, worker_id 51 | 52 | 53 | def main(args): 54 | # read dataloader logs 55 | dataloader_log_paths = os.listdir(os.path.join(args.log_dir, "dataloader")) 56 | dataloader_log_paths = [ 57 | os.path.join(args.log_dir, "dataloader", log) 58 | for log in dataloader_log_paths 59 | ] 60 | dataloader_logs = {} 61 | for log in dataloader_log_paths: 62 | rank, virtual_rank, worker_id = parse_rank_from_log_path(log) 63 | with open(log, "r") as f: 64 | dataloader_logs[(rank, virtual_rank, worker_id)] = f.readlines() 65 | planner_logs = {} 66 | worker_logs = defaultdict(dict) 67 | for rank, virtual_rank, worker_id in dataloader_logs.keys(): 68 | if rank 
== 0 and virtual_rank == 0: 69 | planner_logs[worker_id] = dataloader_logs[ 70 | (rank, virtual_rank, worker_id) 71 | ] 72 | else: 73 | worker_logs[(rank, virtual_rank)][worker_id] = dataloader_logs[ 74 | (rank, virtual_rank, worker_id) 75 | ] 76 | last_generated_ep, _ = get_last_generated_ep(planner_logs) 77 | print(f"Last generated EP on planner: {last_generated_ep}") 78 | for (rank, virtual_rank), worker_dict in worker_logs.items(): 79 | last_received_ep, _ = get_last_received_ep(worker_dict) 80 | print( 81 | f"Last received EP on worker r{rank}-vr{virtual_rank}: " 82 | f"{last_received_ep}" 83 | ) 84 | 85 | 86 | if __name__ == "__main__": 87 | main(parse_args()) 88 | -------------------------------------------------------------------------------- /scripts/check_executor_logs.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | from datetime import datetime 6 | 7 | # Checks executor logs for slow instructions and mismatched instructions, 8 | # for debugging purposes 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("log_paths", nargs="+", default=[]) 14 | return parser.parse_args() 15 | 16 | 17 | def main(args): 18 | for log in args.log_paths: 19 | rank = log.split("_")[-1].split(".")[0] 20 | mismatched_instrs = {} 21 | last_time = None 22 | with open(log, "r") as f: 23 | for line in f: 24 | datetime_str = line.split(" ")[1].split(",")[0] 25 | time = datetime.strptime(datetime_str, "%H:%M:%S") 26 | if last_time is not None: 27 | timedelta = time - last_time 28 | if timedelta.total_seconds() > 1: 29 | print( 30 | f"Rank {rank} has slow instr " 31 | f"({timedelta.total_seconds()} seconds):\n\t{line}" 32 | ) 33 | last_time = time 34 | if "Executing instruction:" in line: 35 | instr = line.split(":")[-1].split("\x1b")[0].strip() 36 | if instr not in mismatched_instrs: 37 | mismatched_instrs[instr] = 1 38 | else: 39 | mismatched_instrs[instr] += 1 40 | elif "finished" in line: 41 | instr = ( 42 | line.split(":")[-1] 43 | .split("finished")[0] 44 | .split("\x1b")[0] 45 | .strip() 46 | ) 47 | mismatched_instrs[instr] -= 1 48 | assert mismatched_instrs[instr] >= 0 49 | if mismatched_instrs[instr] == 0: 50 | del mismatched_instrs[instr] 51 | for instr, cnt in mismatched_instrs.items(): 52 | print( 53 | f"Rank {rank} has mismatched instructions: " 54 | f"{instr}, repeat cnt: {cnt}" 55 | ) 56 | 57 | 58 | if __name__ == "__main__": 59 | main(parse_args()) 60 | -------------------------------------------------------------------------------- /scripts/estimate_memory_usage.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pandas as pd 5 | 6 | from dynapipe.utils.memory_utils import TransformerMemoryModel 7 | 8 | # Script for estimating memory usage of transformer models with different 9 | # sequence lengths and microbatch sizes. 
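# Note: all sizes reported by TransformerMemoryModel are in MB; see
# dynapipe/utils/memory_utils.py for the underlying formulas.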
10 | 11 | 12 | def main(): 13 | hidden_dim = 8192 14 | num_attn_heads = 64 15 | kv_channels = 128 16 | mlp_hidden_dim = 6560 17 | bytes_per_element = 2 18 | tp_size = 8 19 | 20 | prod_ = 1024 * 12 21 | js = [] 22 | for enc_sec_len in [512, 768, 1024, 1536, 2048]: 23 | dec_sec_len = enc_sec_len / 2 24 | mbs = prod_ / enc_sec_len 25 | enc_mem_model = TransformerMemoryModel( 26 | batch_size=mbs, 27 | sequence_length=enc_sec_len, 28 | hidden_dim=hidden_dim, 29 | num_attn_heads=num_attn_heads, 30 | mlp_hidden_dim=mlp_hidden_dim, 31 | kv_channels=kv_channels, 32 | bytes_per_element=bytes_per_element, 33 | optimizer_state_multiplier=12, 34 | tp_size=tp_size, 35 | ) 36 | dec_mem_model = TransformerMemoryModel( 37 | batch_size=mbs, 38 | sequence_length=dec_sec_len, 39 | hidden_dim=hidden_dim, 40 | num_attn_heads=num_attn_heads, 41 | mlp_hidden_dim=mlp_hidden_dim, 42 | kv_channels=kv_channels, 43 | bytes_per_element=bytes_per_element, 44 | optimizer_state_multiplier=12, 45 | tp_size=tp_size, 46 | is_decoder=True, 47 | ) 48 | j = { 49 | "enc_seq_len": enc_sec_len, 50 | "dec_seq_len": dec_sec_len, 51 | "mbs": mbs, 52 | "enc_per_layer_output_memory": enc_mem_model.get_output_memory(), 53 | "enc_per_layer_activation_memory": enc_mem_model.get_activation_memory(), # noqa 54 | "enc_per_layer_model_state_memory": enc_mem_model.get_model_state_memory(), # noqa 55 | "dec_per_layer_output_memory": dec_mem_model.get_output_memory(), # noqa 56 | "dec_per_layer_activation_memory": dec_mem_model.get_activation_memory(), # noqa 57 | "dec_per_layer_model_state_memory": dec_mem_model.get_model_state_memory(), # noqa 58 | } 59 | js.append(j) 60 | 61 | df = pd.DataFrame(js) 62 | df.to_csv("out.csv", index=False) 63 | print(df) 64 | return df 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /scripts/generate_random_mb_spec.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | import json 6 | 7 | import numpy as np 8 | 9 | # Generate random microbatch specifications using specified distribution 10 | 11 | 12 | def parse_arguments(): 13 | # Parse arguments 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "-m", 17 | "--microbatches", 18 | type=int, 19 | required=True, 20 | help="Number of microbatches.", 21 | ) 22 | parser.add_argument( 23 | "-s", "--stages", type=int, required=True, help="Number of stages." 
24 | ) 25 | parser.add_argument( 26 | "-d", 27 | "--device-assignment", 28 | type=str, 29 | required=True, 30 | help="Assignment of stages to devices.", 31 | ) 32 | parser.add_argument( 33 | "-o", 34 | "--output", 35 | type=str, 36 | help="Path where the output config json file will be stored.", 37 | ) 38 | parser.add_argument( 39 | "--distribution", 40 | type=str, 41 | choices=["normal", "geometric"], 42 | help="Distribution used to generate the multipliers.", 43 | ) 44 | parser.add_argument( 45 | "-p", 46 | "--success-probability", 47 | type=float, 48 | default=0.5, 49 | help="Success probability for the geometric distribution.", 50 | ) 51 | parser.add_argument( 52 | "-std", 53 | "--stddev", 54 | type=float, 55 | default=1.0, 56 | help="Standard deviation of the microbatch multipliers.", 57 | ) 58 | 59 | args = parser.parse_args() 60 | args.device_assignment = [ 61 | int(x) for x in args.device_assignment.split(",") 62 | ] 63 | if args.output is None: 64 | args.output = "mb_spec" 65 | if args.distribution == "geometric": 66 | dist_spec = "-p" + str(args.success_probability) 67 | else: 68 | dist_spec = "-std" + str(args.stddev) 69 | args.output += f"_{args.distribution}{dist_spec}" 70 | args.output += ( 71 | f"_m{args.microbatches}s{args.stages}p{args.stddev}.json" 72 | ) 73 | return args 74 | 75 | 76 | if __name__ == "__main__": 77 | args = parse_arguments() 78 | 79 | # Load model specification 80 | base_exec_times = [1000] * args.stages 81 | microbatch_multipliers = [] 82 | for m in range(args.microbatches): 83 | mult = 0 84 | while mult <= 0: 85 | if args.distribution == "normal": 86 | mult = np.random.normal(1, args.stddev) 87 | elif args.distribution == "geometric": 88 | mult = np.random.geometric(args.success_probability) 89 | else: 90 | raise ValueError( 91 | "Unknown distribution: {}".format(args.distribution) 92 | ) 93 | microbatch_multipliers.append([mult] * args.stages) 94 | per_stage_device_assignment = args.device_assignment 95 | per_device_stage_assignment = [ 96 | [] for _ in range(max(args.device_assignment) + 1) 97 | ] 98 | for i, d in enumerate(per_stage_device_assignment): 99 | per_device_stage_assignment[d].append(i) 100 | 101 | json_dict = { 102 | "base_exec_times": base_exec_times, 103 | "per_microbatch_multiplier": microbatch_multipliers, 104 | "per_device_stage_assignment": per_device_stage_assignment, 105 | } 106 | with open(args.output, "w") as f: 107 | json.dump(json_dict, f, indent=4) 108 | -------------------------------------------------------------------------------- /scripts/plot_microbatch_permutation.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | import json 6 | import os 7 | 8 | from dynapipe.model import SchedulerMinibatchSpec 9 | from dynapipe.schedule_opt import get_available_schedulers, get_scheduler_class 10 | 11 | # Reads in a microbatch specification json file, generate the resulting 12 | # timeline after permutating the microbatches. 13 | 14 | 15 | def parse_arguments(): 16 | # Parse arguments 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "-m", 20 | "--microbatch-spec", 21 | type=str, 22 | required=True, 23 | help="Path to microbatch specification json file.", 24 | ) 25 | parser.add_argument( 26 | "-o", 27 | "--output", 28 | type=str, 29 | help="Path where the output timeline json file will be stored." 
30 | "Default to {microbatch_spec_name}" 31 | "_{schedule_simulator}_timeline.json.", 32 | ) 33 | parser.add_argument( 34 | "-p", 35 | "--microbatch-permutation", 36 | type=str, 37 | help="Reorder the microbatches according to the permutation.", 38 | ) 39 | parser.add_argument( 40 | "--schedule-simulator", 41 | type=str, 42 | choices=get_available_schedulers(), 43 | default="wait-free-cyclic", 44 | help="The schedule simulator to use. Defaults to wait-free-cyclic.", 45 | ) 46 | 47 | args = parser.parse_args() 48 | if args.output is None: 49 | mbspec_basename = os.path.basename(args.microbatch_spec).rsplit( 50 | ".", 1 51 | )[0] 52 | args.output = ( 53 | mbspec_basename + f"_{args.schedule_simulator}" + "_timeline.json" 54 | ) 55 | if args.microbatch_permutation is not None: 56 | args.microbatch_permutation = [ 57 | int(x) for x in args.microbatch_permutation.split(",") 58 | ] 59 | return args 60 | 61 | 62 | if __name__ == "__main__": 63 | args = parse_arguments() 64 | 65 | # Load model specification 66 | with open(args.microbatch_spec, "r") as f: 67 | microbatch_spec = json.load(f) 68 | base_exec_times = microbatch_spec["base_exec_times"] 69 | per_microbatch_multiplier = microbatch_spec["per_microbatch_multiplier"] 70 | if args.microbatch_permutation is not None: 71 | assert len(args.microbatch_permutation) == len( 72 | per_microbatch_multiplier 73 | ), ( 74 | "The length of the microbatch permutation " 75 | "({}) must be the same as the number " 76 | "of microbatches ({}).".format( 77 | len(args.microbatch_permutation), len(base_exec_times) 78 | ) 79 | ) 80 | new_per_microbatch_multiplier = [] 81 | for i in range(len(per_microbatch_multiplier)): 82 | new_per_microbatch_multiplier.append( 83 | per_microbatch_multiplier[args.microbatch_permutation[i]] 84 | ) 85 | per_microbatch_multiplier = new_per_microbatch_multiplier 86 | per_device_stage_assignment = microbatch_spec[ 87 | "per_device_stage_assignment" 88 | ] 89 | per_stage_device_assignment = [-1] * len(base_exec_times) 90 | for dev, stages in enumerate(per_device_stage_assignment): 91 | for stage in stages: 92 | per_stage_device_assignment[stage] = dev 93 | for stage, dev in enumerate( 94 | per_stage_device_assignment[: len(per_stage_device_assignment) // 2] 95 | ): 96 | assert dev != -1, f"Stage {stage} is not assigned to any device." 97 | assert dev == per_stage_device_assignment[-stage - 1], ( 98 | f"FW stage {stage} is assigned to device {dev} " 99 | f"but BW stage {-stage - 1} is assigned to " 100 | f"device {per_stage_device_assignment[-stage - 1]}." 
101 | ) 102 | 103 | # get scheduler 104 | fw_len = len(base_exec_times) // 2 105 | scheduler_params = SchedulerMinibatchSpec( 106 | base_exec_times[:fw_len], 107 | [0] * (fw_len - 1), 108 | [1000] * fw_len, 109 | [1000] * fw_len, 110 | per_stage_device_assignment[:fw_len], 111 | base_exec_times[fw_len:], 112 | bw_comm_times=[0] * (fw_len - 1), 113 | ) 114 | scheduler_class = get_scheduler_class(args.schedule_simulator) 115 | simulator = scheduler_class(scheduler_params, separate_comm_stage=False) 116 | 117 | # run simulation 118 | timeline_json = simulator.schedule( 119 | len(per_microbatch_multiplier), 120 | microbatch_multiplier=per_microbatch_multiplier, 121 | ) 122 | print("# Makespan: {} ms".format(simulator.get_makespan() / 1000.0)) 123 | # save timeline 124 | with open(args.output, "w") as f: 125 | json.dump(timeline_json, f, indent=4) 126 | -------------------------------------------------------------------------------- /scripts/pre-commit.hook: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ^ Note the above "shebang" line. This says "This is an executable shell script" 3 | # Name this script "pre-commit" and place it in the ".git/hooks/" directory 4 | 5 | # If any command fails, exit immediately with that command's exit status 6 | set -eo pipefail 7 | 8 | # Run isort against all code 9 | isort . --profile black --line-length 79 10 | echo "isort passed!" 11 | 12 | # Run black against all code 13 | black . --line-length 79 --check 14 | echo "black passed!" 15 | 16 | # Run flake8 against all code in the `source_code` directory 17 | flake8 . --extend-ignore F405,E203 18 | echo "flake8 passed!" -------------------------------------------------------------------------------- /scripts/run_format.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | # If any command fails, exit immediately with that command's exit status 7 | set -eo pipefail 8 | 9 | # Run isort against all code 10 | isort . --profile black --line-length 79 11 | echo "isort on code passed!" 12 | 13 | # Run black against all code 14 | black . --line-length 79 15 | echo "black on code passed!" 16 | 17 | # Run flake8 against all code in the `source_code` directory 18 | flake8 . --extend-ignore F405,E203 19 | echo "flake8 on code passed!" 20 | 21 | clang-format -i ./dynapipe/data_opt/dp_helper.cpp -------------------------------------------------------------------------------- /scripts/simulation/schedule_under_dynamic_mb.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | 6 | import numpy as np 7 | from shift_trace_json import ( 8 | construct_exec_time_dict, 9 | convert_to_multistream_comm, 10 | ) 11 | from tqdm import tqdm 12 | 13 | from dynapipe.model import ( 14 | DynaPipeMicrobatch, 15 | DynaPipeMinibatch, 16 | get_uniform_cluster, 17 | ) 18 | from dynapipe.schedule_opt.execution_planner import optimize_schedule 19 | 20 | 21 | def get_hetero_minibatch( 22 | microbatch_multiplier, comm_factor=1 23 | ) -> DynaPipeMinibatch: 24 | fw_times = [4000] * 16 25 | memory_multiplier = microbatch_multiplier 26 | microbatches = [] 27 | for i in range(len(microbatch_multiplier)): 28 | current_fw_times = [ 29 | fw_times[j] * microbatch_multiplier[i] 30 | for j in range(len(fw_times)) 31 | ] 32 | current_bw_times = [2 * t for t in current_fw_times] 33 | microbatch = DynaPipeMicrobatch(str(i)) 34 | microbatch.set_fw_exec_times(current_fw_times) 35 | microbatch.set_bw_exec_times(current_bw_times) 36 | microbatch.set_fw_comm_size( 37 | [200 * comm_factor * microbatch_multiplier[i]] 38 | * (len(fw_times) - 1) 39 | ) 40 | microbatch.set_bw_comm_size( 41 | [200 * comm_factor * microbatch_multiplier[i]] 42 | * (len(fw_times) - 1) 43 | ) 44 | microbatch.set_model_state_memory([4000] * len(fw_times)) 45 | microbatch.set_model_stored_activation_memory( 46 | [8000 * memory_multiplier[i]] * len(fw_times) 47 | ) 48 | microbatch.set_model_peak_activation_memory( 49 | [16000 * memory_multiplier[i]] * len(fw_times) 50 | ) 51 | microbatch.set_activation_shapes( 52 | [[(64, 128, 512)]] * (len(fw_times) // 2) 53 | + [[(64, 128, 512), (64, 128, 512)]] * (len(fw_times) // 2) 54 | ) 55 | microbatches.append(microbatch) 56 | minibatch = DynaPipeMinibatch("test", microbatches) 57 | return minibatch 58 | 59 | 60 | def gen_micro_batch_multipliers(n_iters, n_microbatches, std): 61 | rng = np.random.default_rng(seed=48) 62 | for _ in range(n_iters): 63 | m = np.clip(rng.normal(1, std, size=n_microbatches), 0.1, 10) 64 | normalized_m = m / (sum(m) / n_microbatches) 65 | yield normalized_m 66 | 67 | 68 | def schedule_minibatch( 69 | n_stages, 70 | sch_type, 71 | n_iters, 72 | n_microbatches=16, 73 | std=0.1, 74 | multistream=False, 75 | comm_factor=1, 76 | ): 77 | nlayers = 16 78 | assert nlayers % n_stages == 0 79 | layers_per_stage = nlayers // n_stages 80 | device_assignment = [] 81 | for i in range(n_stages): 82 | device_assignment += [i] * layers_per_stage 83 | cluster = get_uniform_cluster(n_stages, intra_node_bw=1e6) 84 | 85 | if sch_type == "1F1B": 86 | try_permutations = False 87 | else: 88 | try_permutations = True 89 | makespans = [] 90 | for multiplier in gen_micro_batch_multipliers( 91 | n_iters, n_microbatches, std 92 | ): 93 | # multiplier = [1] * 16 94 | if multistream: 95 | try_permutations = False 96 | minibatch = get_hetero_minibatch(multiplier, comm_factor=comm_factor) 97 | ( 98 | _, 99 | _, 100 | _, 101 | min_makespan, 102 | min_stats, 103 | min_instructions, 104 | ) = optimize_schedule( 105 | sch_type, 106 | minibatch, 107 | cluster, 108 | device_assignment, 109 | try_permutations=try_permutations, 110 | include_memory_stats=True, 111 | progress_bar=False, 112 | memory_limit=float("inf"), 113 | ) 114 | if not multistream: 115 | makespans.append(min_makespan) 116 | continue 117 | # produce a reference trace 118 | ref_multiplier = [1] * 16 119 | ref_minibatch = get_hetero_minibatch(ref_multiplier) 120 | ( 121 | _, 122 | _, 123 | _, 124 | ref_min_makespan, 125 | ref_min_stats, 126 | ref_min_instructions, 127 | ) = 
optimize_schedule( 128 | sch_type, 129 | ref_minibatch, 130 | cluster, 131 | device_assignment, 132 | try_permutations=False, 133 | include_memory_stats=True, 134 | progress_bar=False, 135 | memory_limit=float("inf"), 136 | ) 137 | sch_trace = min_stats[-1] 138 | ref_trace = ref_min_stats[-1] 139 | sch_time_dict = construct_exec_time_dict(sch_trace) 140 | multistream_ref_trace, ref_makespan = convert_to_multistream_comm( 141 | ref_trace, sch_time_dict 142 | ) 143 | multistream_sch_trace, makespan = convert_to_multistream_comm( 144 | sch_trace 145 | ) 146 | makespans.append((makespan, ref_makespan)) 147 | # trace_path = "test_wfcyclic_trace.json" 148 | # with open(trace_path, "w") as f: 149 | # import json 150 | # json.dump(min_stats[-1], f) 151 | return np.mean(np.array(makespans), axis=0) 152 | 153 | 154 | if __name__ == "__main__": 155 | exists = os.path.isfile("./compare_schedule_multistream.csv") 156 | with open("./compare_schedule_multistream.csv", "a") as f: 157 | if not exists: 158 | f.write( 159 | "n_stages,sch_type,std,comm_factor,makespan,ref_makespan\n" 160 | ) 161 | # for n_stages in tqdm([2, 4, 8, 16]): 162 | for n_stages in tqdm([16]): 163 | # for sch_type in tqdm(["1F1B", "wait-free-cyclic"], leave=False): 164 | for sch_type in tqdm(["wait-free-cyclic"], leave=False): 165 | for std in tqdm( 166 | [0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4], leave=False 167 | ): 168 | for comm_factor in [0.2, 0.4, 0.8, 1.0, 1.2, 1.6, 3.2]: 169 | makespan, ref_makespan = schedule_minibatch( 170 | n_stages, 171 | sch_type, 172 | 1, 173 | std=std, 174 | multistream=True, 175 | comm_factor=comm_factor, 176 | ) 177 | print( 178 | f"{n_stages},{sch_type},{std}," 179 | f"{comm_factor},{makespan},{ref_makespan}" 180 | ) 181 | f.write( 182 | f"{n_stages},{sch_type},{std}," 183 | f"{comm_factor},{makespan},{ref_makespan}\n" 184 | ) 185 | f.flush() 186 | -------------------------------------------------------------------------------- /scripts/validate_execution_plans.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | import pickle 6 | 7 | from dynapipe.pipe.utils import check_deadlock 8 | 9 | # Validate execution plans by checking for deadlocks 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("execution_plan", type=str) 15 | return parser.parse_args() 16 | 17 | 18 | def main(args): 19 | with open(args.execution_plan, "rb") as f: 20 | eps = pickle.load(f) 21 | check_deadlock(eps) 22 | 23 | 24 | if __name__ == "__main__": 25 | main(parse_args()) 26 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from pybind11.setup_helpers import build_ext, intree_extensions 5 | from setuptools import find_packages, setup 6 | 7 | ext_modules = intree_extensions( 8 | ["dynapipe/data_opt/dp_helper.cpp"], 9 | ) 10 | 11 | setup( 12 | name="dynapipe", 13 | version="0.0.1", 14 | packages=find_packages(), 15 | cmdclass={"build_ext": build_ext}, 16 | ext_modules=ext_modules, 17 | ) 18 | -------------------------------------------------------------------------------- /tests/test_dataloader/test_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # Note: this test requires torch 5 | # to run this test, exec: 6 | # If running hanging tests or multi-node tests: 7 | # DYNAPIPE_DEBUG=DEBUG DYNAPIPE_LOGGING_DEBUG_DIR=./test_debug \ 8 | # torchrun --standalone --nnodes=1 --nproc_per_node=4 test_dataloader.py 9 | # Others: 10 | # DYNAPIPE_DEBUG=DEBUG DYNAPIPE_LOGGING_DEBUG_DIR=./test_debug \ 11 | # torchrun --standalone --nnodes=1 --nproc_per_node=2 test_dataloader.py 12 | 13 | import os 14 | 15 | import pytest 16 | import torch 17 | import torch.distributed as dist 18 | from torch.utils.data import Dataset 19 | 20 | from dynapipe.model import TransformerModelSpec, get_uniform_cluster 21 | from dynapipe.pipe.data_loader import DynaPipeDataLoader, TrainingSpec 22 | from dynapipe.pipe.instructions import ExecutionPlan, ForwardPass 23 | 24 | torch.manual_seed(42) 25 | 26 | 27 | @pytest.fixture(scope="module", autouse=True) 28 | def init_torch_distributed(): 29 | os.environ["MASTER_ADDR"] = "localhost" 30 | os.environ["MASTER_PORT"] = "12355" 31 | torch.distributed.init_process_group("gloo") 32 | 33 | 34 | class DummyDataset(Dataset): 35 | def __init__(self, size, inputs_only=False): 36 | self.size = size 37 | torch.manual_seed(42) 38 | # pre-generate all data 39 | self.enc_seqlen = [] 40 | self.dec_seqlen = [] 41 | self.data = [] 42 | for _ in range(size): 43 | enc_seqlen, dec_seqlen = torch.randint(24, 512, (2,)) 44 | self.enc_seqlen.append(enc_seqlen) 45 | if not inputs_only: 46 | self.dec_seqlen.append(dec_seqlen) 47 | result = { 48 | "text_enc": list( 49 | torch.randint(0, 100, (enc_seqlen,)).numpy() 50 | ), 51 | "text_dec": list( 52 | torch.randint(0, 100, (dec_seqlen,)).numpy() 53 | ), 54 | } 55 | else: 56 | result = { 57 | "text": list(torch.randint(0, 100, (enc_seqlen,)).numpy()), 58 | } 59 | self.data.append(result) 60 | 61 | def __len__(self): 62 | return self.size 63 | 64 | def __getitem__(self, index): 65 | return self.data[index] 66 | 67 | 68 | def dummy_pack_fn(tensors): 69 | # (input, extra) 70 | if len(tensors) == 0: 71 | return [], 0 72 | if isinstance(tensors[0], list): 73 | concated_list = [] 74 | for t in tensors: 75 | concated_list.extend(t) 76 | return concated_list, 0 77 | return torch.cat(tensors, dim=0), 0 78 | 79 | 80 | def dummy_constructor_fn( 81 | encoder_input, 82 | encoder_extra, 83 | decoder_input, 84 | decoder_extra, 85 | encoder_seqlen, 86 | decoder_seqlen, 87 | ): 88 | encoder_padding_len = encoder_seqlen - len(encoder_input) 89 | if decoder_input is not None: 90 | decoder_padding_len = decoder_seqlen - len(decoder_input) 91 | encoder_input = torch.tensor(encoder_input, dtype=torch.long) 92 | if decoder_input is not None: 93 | decoder_input = torch.tensor(decoder_input, dtype=torch.long) 94 | encoder_padded = torch.cat( 95 | [ 96 | encoder_input, 97 | torch.zeros( 98 | 
encoder_padding_len, 99 | dtype=encoder_input.dtype, 100 | device=encoder_input.device, 101 | ), 102 | ], 103 | dim=0, 104 | ) 105 | if decoder_input is not None: 106 | decoder_padded = torch.cat( 107 | [ 108 | decoder_input, 109 | torch.zeros( 110 | decoder_padding_len, 111 | dtype=decoder_input.dtype, 112 | device=decoder_input.device, 113 | ), 114 | ], 115 | dim=0, 116 | ) 117 | return { 118 | "text_enc": encoder_padded, 119 | "text_dec": decoder_padded, 120 | } 121 | else: 122 | return { 123 | "text": encoder_padded, 124 | } 125 | 126 | 127 | def get_mb_shape_from_ep(ep: ExecutionPlan): 128 | fw_shapes = [] 129 | for instr in ep.instructions: 130 | if isinstance(instr, ForwardPass): 131 | fw_shapes.append(instr.buffer_shapes) 132 | return fw_shapes 133 | 134 | 135 | def test_joint_data_loader(inputs_only=False): 136 | cluster_spec = get_uniform_cluster(2) 137 | if inputs_only: 138 | train_spec = TrainingSpec( 139 | "test_cm.pkl", 140 | cluster_spec, 141 | TransformerModelSpec(8, 0, 1024, 128, 65536, 128), 142 | 1, 143 | 2, 144 | 0, 145 | [0, 0, 0, 0, 1, 1, 1, 1], 146 | 800000, # ignore memory limit for this test 147 | prefetch_buffer_size=2, 148 | ) 149 | else: 150 | train_spec = TrainingSpec( 151 | "test_cm.pkl", 152 | cluster_spec, 153 | TransformerModelSpec(4, 4, 1024, 128, 65536, 128), 154 | 1, 155 | 2, 156 | 0, 157 | [0, 0, 0, 0, 1, 1, 1, 1], 158 | 800000, # ignore memory limit for this test 159 | prefetch_buffer_size=2, 160 | model_type="t5", 161 | ) 162 | rank = dist.get_rank() 163 | is_kv_host = rank == 0 164 | data_loader = DynaPipeDataLoader( 165 | train_spec, 166 | DummyDataset(256 * 10, inputs_only=inputs_only), 167 | pack_fn=dummy_pack_fn, 168 | constructor_fn=dummy_constructor_fn, 169 | is_kv_host=is_kv_host, 170 | node_rank=0, 171 | node_local_rank=rank, 172 | dp_rank=0, 173 | pp_rank=rank, 174 | batch_size=256, 175 | shuffle=False, 176 | num_workers=2, 177 | num_preprocess_workers=2, 178 | pin_memory=True, 179 | encoder_key="text_enc" if not inputs_only else "text", 180 | decoder_key="text_dec" if not inputs_only else None, 181 | ) 182 | batch_idx = 0 183 | for batch, ep in data_loader: 184 | if rank == 0: 185 | assert batch is not None 186 | ep_shapes = get_mb_shape_from_ep(ep) 187 | assert len(ep_shapes) == len(batch) 188 | for microbatch, ep_shape in zip(batch, ep_shapes): 189 | if not inputs_only: 190 | enc_seqlen, dec_seqlen = ( 191 | microbatch["text_enc"].shape[1], 192 | microbatch["text_dec"].shape[1], 193 | ) 194 | enc_mbs, dec_mbs = ( 195 | microbatch["text_enc"].shape[0], 196 | microbatch["text_dec"].shape[0], 197 | ) 198 | else: 199 | enc_seqlen = microbatch["text"].shape[1] 200 | enc_mbs = microbatch["text"].shape[0] 201 | dec_mbs = enc_mbs 202 | dec_seqlen = 0 203 | assert enc_mbs == dec_mbs 204 | # encoder only have ep_shape size 1 205 | assert len(ep_shape) == 1 206 | # test shape rounding 207 | assert enc_seqlen % 8 == 0 208 | assert dec_seqlen % 8 == 0 209 | mbs_from_ep = ep_shape[0][0] 210 | enc_seqlen_from_ep = ep_shape[0][1] 211 | assert mbs_from_ep == enc_mbs 212 | assert enc_seqlen_from_ep == enc_seqlen 213 | # get enc and decoder len from rank 1 214 | mbs_rank1_ep_tensor = torch.empty(1, dtype=torch.int64) 215 | encoder_ep_seqlen_tensor = torch.empty(1, dtype=torch.int64) 216 | decoder_ep_seqlen_tensor = torch.empty(1, dtype=torch.int64) 217 | dist.recv(tensor=mbs_rank1_ep_tensor, src=1) 218 | dist.recv(tensor=encoder_ep_seqlen_tensor, src=1) 219 | dist.recv(tensor=decoder_ep_seqlen_tensor, src=1) 220 | mbs_rank1_ep = mbs_rank1_ep_tensor.item() 
221 | encoder_ep_seqlen = encoder_ep_seqlen_tensor.item() 222 | decoder_ep_seqlen = decoder_ep_seqlen_tensor.item() 223 | assert mbs_rank1_ep == enc_mbs 224 | assert dec_seqlen == decoder_ep_seqlen 225 | assert enc_seqlen == encoder_ep_seqlen 226 | print(f"batch {batch_idx} passed") 227 | batch_idx += 1 228 | else: 229 | assert batch is not None 230 | assert ep is not None 231 | ep_shapes = get_mb_shape_from_ep(ep) 232 | for ep_shape in ep_shapes: 233 | if not inputs_only: 234 | assert len(ep_shape) == 2 235 | assert ep_shape[0][0] == ep_shape[1][0] 236 | mbs_from_ep = ep_shape[0][0] 237 | enc_seqlen_from_ep = ep_shape[0][1] 238 | dec_seqlen_from_ep = ep_shape[1][1] 239 | else: 240 | assert len(ep_shape) == 1 241 | mbs_from_ep = ep_shape[0][0] 242 | enc_seqlen_from_ep = ep_shape[0][1] 243 | dec_seqlen_from_ep = 0 244 | mbs_tensor = torch.tensor(mbs_from_ep, dtype=torch.int64) 245 | enc_seqlen_tensor = torch.tensor( 246 | enc_seqlen_from_ep, dtype=torch.int64 247 | ) 248 | dec_seqlen_tensor = torch.tensor( 249 | dec_seqlen_from_ep, dtype=torch.int64 250 | ) 251 | dist.send(tensor=mbs_tensor, dst=0) 252 | dist.send(tensor=enc_seqlen_tensor, dst=0) 253 | dist.send(tensor=dec_seqlen_tensor, dst=0) 254 | dist.barrier() 255 | 256 | 257 | def test_joint_data_loader_hanging(): 258 | cluster_spec = get_uniform_cluster(4) 259 | train_spec = TrainingSpec( 260 | "test_cm.pkl", 261 | cluster_spec, 262 | TransformerModelSpec(4, 4, 1024, 128, 65536, 128), 263 | 1, 264 | 4, 265 | 0, 266 | [0, 0, 1, 1, 2, 2, 3, 3], 267 | 800000, # ignore memory limit for this test 268 | prefetch_buffer_size=32, 269 | model_type="t5", 270 | ) 271 | rank = dist.get_rank() 272 | data_loader = DynaPipeDataLoader( 273 | train_spec, 274 | DummyDataset(256 * 1000), 275 | pack_fn=dummy_pack_fn, 276 | constructor_fn=dummy_constructor_fn, 277 | is_kv_host=rank == 0, 278 | node_rank=0, 279 | node_local_rank=rank, 280 | dp_rank=0, 281 | pp_rank=rank, 282 | batch_size=256, 283 | shuffle=False, 284 | num_workers=2, 285 | num_preprocess_workers=32, 286 | pin_memory=True, 287 | ) 288 | for idx, (batch, ep) in enumerate(data_loader): 289 | if rank == 0: 290 | print("Progress: Iteration {}".format(idx)) 291 | dist.barrier() 292 | dist.barrier() 293 | 294 | 295 | def test_joint_data_loader_multiple_nodes(): 296 | cluster_spec = get_uniform_cluster(4) 297 | train_spec = TrainingSpec( 298 | "test_cm.pkl", 299 | cluster_spec, 300 | TransformerModelSpec(4, 4, 1024, 128, 65536, 128), 301 | 1, 302 | 4, 303 | 0, 304 | [0, 0, 1, 1, 2, 2, 3, 3], 305 | 800000, # ignore memory limit for this test 306 | prefetch_buffer_size=32, 307 | model_type="t5", 308 | ) 309 | rank = dist.get_rank() 310 | data_loader = DynaPipeDataLoader( 311 | train_spec, 312 | DummyDataset(256 * 1000), 313 | pack_fn=dummy_pack_fn, 314 | constructor_fn=dummy_constructor_fn, 315 | is_kv_host=rank == 0, 316 | node_rank=rank // 2, 317 | node_local_rank=rank % 2, 318 | node_size=2, 319 | dp_rank=0, 320 | pp_rank=rank, 321 | batch_size=256, 322 | shuffle=False, 323 | num_workers=2, 324 | num_preprocess_workers=32, 325 | pin_memory=True, 326 | ) 327 | for idx, (batch, ep) in enumerate(data_loader): 328 | if rank == 0: 329 | print("Progress: Iteration {}".format(idx)) 330 | dist.barrier() 331 | dist.barrier() 332 | 333 | 334 | def test_joint_data_loader_with_virtual_ranks(): 335 | cluster_spec = get_uniform_cluster(2) 336 | train_spec = TrainingSpec( 337 | "test_cm.pkl", 338 | cluster_spec, 339 | TransformerModelSpec(4, 4, 1024, 128, 65536, 128), 340 | 1, 341 | 2, 342 | 0, 343 | [0, 0, 1, 
1, 0, 0, 1, 1], 344 | 800000, # ignore memory limit for this test 345 | prefetch_buffer_size=2, 346 | model_type="t5", 347 | ) 348 | rank = dist.get_rank() 349 | data_loader_0 = DynaPipeDataLoader( 350 | train_spec, 351 | DummyDataset(256 * 10), 352 | pack_fn=dummy_pack_fn, 353 | constructor_fn=dummy_constructor_fn, 354 | is_kv_host=True if rank == 0 else False, 355 | node_rank=0, 356 | node_local_rank=rank, 357 | dp_rank=0, 358 | pp_rank=rank, 359 | virtual_pp_rank=0, 360 | batch_size=256, 361 | shuffle=False, 362 | num_workers=2, 363 | num_preprocess_workers=2, 364 | pin_memory=True, 365 | ) 366 | data_loader_1 = DynaPipeDataLoader( 367 | train_spec, 368 | DummyDataset(256 * 10), 369 | pack_fn=dummy_pack_fn, 370 | constructor_fn=dummy_constructor_fn, 371 | is_kv_host=False, 372 | node_rank=0, 373 | node_local_rank=rank, 374 | node_size=1, 375 | dp_rank=0, 376 | pp_rank=rank, 377 | virtual_pp_rank=1, 378 | batch_size=256, 379 | shuffle=False, 380 | num_workers=2, 381 | pin_memory=True, 382 | ) 383 | for it, ((batch0, ep0), (batch1, ep1)) in enumerate( 384 | zip(data_loader_0, data_loader_1) 385 | ): 386 | assert len(batch0) == len( 387 | batch1 388 | ), "batch size mismatch ({}, {}) at iter {}".format( 389 | len(batch0), len(batch1), it 390 | ) 391 | for mb0, mb1 in zip(batch0, batch1): 392 | assert torch.equal(mb0["encoder_input"], mb1["encoder_input"]) 393 | assert torch.equal(mb0["decoder_input"], mb1["decoder_input"]) 394 | assert ep0 == ep1 395 | dist.barrier() 396 | 397 | 398 | os.environ["MASTER_ADDR"] = "localhost" 399 | os.environ["MASTER_PORT"] = "12355" 400 | torch.distributed.init_process_group("gloo") 401 | # test hanging issue 402 | # test_joint_data_loader_hanging() 403 | # test multi-node preprocessing 404 | # test_joint_data_loader_multiple_nodes() 405 | # test without virtual ranks 406 | test_joint_data_loader(inputs_only=True) 407 | # test with virtual ranks 408 | # test_joint_data_loader_with_virtual_ranks() 409 | -------------------------------------------------------------------------------- /tests/test_dev_assign_validation.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pytest 5 | 6 | from dynapipe.model import TransformerModelSpec, get_uniform_cluster 7 | from dynapipe.pipe.utils import validate_device_assignment 8 | 9 | 10 | def get_uniform_model(n_encoder_layers, n_decoder_layers): 11 | return TransformerModelSpec( 12 | n_encoder_layers, n_decoder_layers, 1024, 128, 65536, 128 13 | ) 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "model_spec, cluster_spec, device_assignment, expected", 18 | [ 19 | # linear 20 | ( 21 | get_uniform_model(4, 4), 22 | get_uniform_cluster(8), 23 | [0, 1, 2, 3, 4, 5, 6, 7], 24 | ( 25 | "linear", 26 | set(["wait-free-cyclic", "1F1B"]), 27 | 1, 28 | 1, 29 | ), 30 | ), 31 | # interleaved 1 32 | ( 33 | get_uniform_model(4, 4), 34 | get_uniform_cluster(4), 35 | [0, 1, 2, 3, 0, 1, 2, 3], 36 | ( 37 | "interleaved", 38 | set(["wait-free-cyclic", "interleaved-1F1B"]), 39 | 1, 40 | 2, 41 | ), 42 | ), 43 | # interleaved 2 44 | ( 45 | get_uniform_model(8, 4), 46 | get_uniform_cluster(4), 47 | [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], 48 | ( 49 | "interleaved", 50 | set(["wait-free-cyclic", "interleaved-1F1B"]), 51 | 1, 52 | 3, 53 | ), 54 | ), 55 | # multiple layer per virtual layer 56 | ( 57 | get_uniform_model(4, 4), 58 | get_uniform_cluster(4), 59 | [0, 0, 1, 1, 2, 2, 3, 3], 60 | ( 61 | "linear", 62 | set(["wait-free-cyclic", "1F1B"]), 63 | 2, 64 | 1, 65 | ), 66 | ), 67 | # multiple layer per virtual layer, interleaved 68 | ( 69 | get_uniform_model(8, 8), 70 | get_uniform_cluster(4), 71 | [0, 0, 1, 1, 2, 2, 3, 3, 0, 0, 1, 1, 2, 2, 3, 3], 72 | ( 73 | "interleaved", 74 | set(["wait-free-cyclic", "interleaved-1F1B"]), 75 | 2, 76 | 2, 77 | ), 78 | ), 79 | # decoder only models like GPT (note we specify all layers as 80 | # encoder layers since we assume decoder layers are those containing 81 | # encoder-decoder attention) 82 | ( 83 | get_uniform_model(4, 0), 84 | get_uniform_cluster(4), 85 | [0, 1, 2, 3], 86 | ( 87 | "linear", 88 | set(["wait-free-cyclic", "1F1B"]), 89 | 1, 90 | 1, 91 | ), 92 | ), 93 | ( 94 | get_uniform_model(8, 0), 95 | get_uniform_cluster(4), 96 | [0, 0, 1, 1, 2, 2, 3, 3], 97 | ( 98 | "linear", 99 | set(["wait-free-cyclic", "1F1B"]), 100 | 2, 101 | 1, 102 | ), 103 | ), 104 | # single gpu 105 | ( 106 | get_uniform_model(4, 0), 107 | get_uniform_cluster(1), 108 | [0, 0, 0, 0], 109 | ( 110 | "linear", 111 | set(["wait-free-cyclic"]), 112 | 4, 113 | 1, 114 | ), 115 | ), 116 | # other 117 | ( 118 | get_uniform_model(4, 4), 119 | get_uniform_cluster(4), 120 | [0, 1, 2, 3, 3, 2, 1, 0], 121 | ( 122 | "other", 123 | set(["wait-free-cyclic"]), 124 | 1, 125 | 1, 126 | ), 127 | ), 128 | ], 129 | ) 130 | def test_valid_device_assignments( 131 | model_spec, cluster_spec, device_assignment, expected 132 | ): 133 | ( 134 | device_assignment_type, 135 | valid_schedule_methods, 136 | n_actual_layers_per_virtual_layer, 137 | n_chunks_per_device, 138 | ) = validate_device_assignment(model_spec, cluster_spec, device_assignment) 139 | valid_schedule_methods = set(valid_schedule_methods) 140 | assert device_assignment_type == expected[0] 141 | assert valid_schedule_methods == expected[1] 142 | assert n_actual_layers_per_virtual_layer == expected[2] 143 | assert n_chunks_per_device == expected[3] 144 | 145 | 146 | def test_incorrect_appear_order(): 147 | with pytest.raises(AssertionError): 148 | validate_device_assignment( 149 | get_uniform_model(4, 4), 150 | get_uniform_cluster(4), 151 | [0, 2, 1, 3, 0, 2, 1, 3], 152 | ) 153 | 154 | 155 | def test_incorrect_n_devices(): 156 | with 
pytest.raises(AssertionError): 157 | validate_device_assignment( 158 | get_uniform_model(4, 4), 159 | get_uniform_cluster(8), 160 | [0, 1, 2, 3, 0, 1, 2, 3], 161 | ) 162 | 163 | 164 | def test_incorrect_interleaving_no_decoder(): 165 | with pytest.raises(NotImplementedError): 166 | validate_device_assignment( 167 | get_uniform_model(0, 0), 168 | get_uniform_cluster(4), 169 | [0, 1, 2, 3, 0, 1, 2, 3], 170 | ) 171 | 172 | 173 | if __name__ == "__main__": 174 | pytest.main([__file__]) 175 | -------------------------------------------------------------------------------- /tests/test_instruction_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import pickle 6 | 7 | import pytest 8 | 9 | from dynapipe.pipe.instruction_optimizer import ( 10 | InstructionOptimizer, 11 | _is_forward, 12 | ) 13 | from dynapipe.pipe.instructions import * # noqa: F403 14 | 15 | eps_prefix = "./test_scheduler/" 16 | eps_paths = [ 17 | os.path.join(eps_prefix, x) 18 | for x in os.listdir(eps_prefix) 19 | if x.endswith(".pkl") 20 | ] 21 | 22 | 23 | def load_eps(path): 24 | with open(path, "rb") as f: 25 | serialized_eps = pickle.load(f)[-1] 26 | eps = [ExecutionPlan.deserialize(ep) for ep in serialized_eps] 27 | return eps 28 | 29 | 30 | @pytest.mark.parametrize("eps_path", eps_paths) 31 | def test_inject_comm_finish_instrs(eps_path): 32 | input_eps = load_eps(eps_path) 33 | 34 | for ep in input_eps: 35 | input_instrs = ep.instructions 36 | optimizer = InstructionOptimizer([], n_stages=ep.nstages) 37 | output_instrs = optimizer._inject_comm_finish_instrs(input_instrs) 38 | start_keys = set() 39 | for instruction in output_instrs: 40 | if isinstance(instruction, CommunicationStartInstruction): 41 | key = ( 42 | instruction.microbatch, 43 | instruction.stage, 44 | _is_forward(instruction), 45 | ) 46 | start_keys.add(key) 47 | elif isinstance(instruction, CommunicationFinishInsturction): 48 | key = ( 49 | instruction.microbatch, 50 | instruction.stage, 51 | _is_forward(instruction), 52 | ) 53 | assert ( 54 | key in start_keys 55 | ), "Finish instruction without start instruction: {}".format( 56 | instruction 57 | ) 58 | start_keys.remove(key) 59 | assert ( 60 | len(start_keys) == 0 61 | ), "Start instruction without finish instruction: {}".format( 62 | start_keys 63 | ) 64 | 65 | 66 | @pytest.mark.parametrize("eps_path", eps_paths) 67 | def test_allocate_buffer(eps_path): 68 | input_eps = load_eps(eps_path) 69 | 70 | for ep in input_eps: 71 | input_instrs = ep.instructions 72 | optimizer = InstructionOptimizer([], n_stages=ep.nstages) 73 | output_instrs, n_buffer_slots = optimizer._allocate_buffers( 74 | input_instrs 75 | ) 76 | buffer_slots = [None] * n_buffer_slots 77 | for instr in output_instrs: 78 | if isinstance( 79 | instr, (RecvActivationStart, RecvGradStart, LoadInput) 80 | ): 81 | # instructions that populates buffer slots 82 | for buffer_id, buffer_shape in zip( 83 | instr.buffer_ids, instr.buffer_shapes 84 | ): 85 | assert buffer_slots[buffer_id] is None 86 | buffer_slots[buffer_id] = buffer_shape 87 | elif isinstance( 88 | instr, 89 | ( 90 | ForwardPass, 91 | BackwardPass, 92 | SendActivationStart, 93 | SendGradStart, 94 | ), 95 | ): 96 | # instruction that reads buffer slots 97 | for buffer_id, buffer_shape in zip( 98 | instr.buffer_ids, instr.buffer_shapes 99 | ): 100 | assert buffer_slots[buffer_id] == buffer_shape 101 | elif 
isinstance(instr, FreeBuffer): 102 | # instruction that frees buffer slots 103 | for buffer_id in instr.buffer_ids: 104 | assert buffer_slots[buffer_id] is not None 105 | buffer_slots[buffer_id] = None 106 | assert all([buffer_slot is None for buffer_slot in buffer_slots]) 107 | 108 | 109 | if __name__ == "__main__": 110 | pytest.main([__file__]) 111 | -------------------------------------------------------------------------------- /tests/test_instructions.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from typing import Union 5 | 6 | import pytest 7 | 8 | from dynapipe.pipe.instructions import * # noqa: F403 9 | 10 | comm_start_instrs = [ 11 | SendActivationStart, 12 | SendGradStart, 13 | RecvActivationStart, 14 | RecvGradStart, 15 | ] 16 | comm_finish_instrs = [ 17 | SendActivationFinish, 18 | SendGradFinish, 19 | RecvActivationFinish, 20 | RecvGradFinish, 21 | ] 22 | buffer_instrs = [LoadInput, FreeBuffer] 23 | 24 | 25 | def _test_serialization(instr: Union[PipeInstruction, ExecutionPlan]): 26 | serialized = instr.serialize() 27 | assert isinstance(serialized, bytes) 28 | if isinstance(instr, PipeInstruction): 29 | deserialized, remaining_bytes = PipeInstruction.deserialize(serialized) 30 | assert len(remaining_bytes) == 0 31 | else: 32 | deserialized = ExecutionPlan.deserialize(serialized) 33 | assert instr == deserialized 34 | # test casting to str 35 | serialized_casted = ( 36 | serialized.decode("iso-8859-1").encode().decode().encode("iso-8859-1") 37 | ) 38 | if isinstance(instr, PipeInstruction): 39 | deserialized_casted, remaining_bytes = PipeInstruction.deserialize( 40 | serialized_casted 41 | ) 42 | assert len(remaining_bytes) == 0 43 | else: 44 | deserialized_casted = ExecutionPlan.deserialize(serialized_casted) 45 | assert instr == deserialized_casted 46 | 47 | 48 | @pytest.mark.parametrize("instr_cls", comm_start_instrs) 49 | @pytest.mark.parametrize("n_tensors", [1, 2, 3]) 50 | @pytest.mark.parametrize("comm_dims", [0, 1, 2, 3]) 51 | def test_serialization_comm_start_same_dim( 52 | instr_cls: Type[CommunicationStartInstruction], 53 | n_tensors: int, 54 | comm_dims: int, 55 | ): 56 | buffer_ids = list(range(n_tensors)) 57 | buffer_shapes = [ 58 | tuple([2 for _ in range(comm_dims)]) for _ in range(n_tensors) 59 | ] 60 | instr = instr_cls( 61 | 0, 1, peer=0, buffer_shapes=buffer_shapes, buffer_ids=buffer_ids 62 | ) 63 | _test_serialization(instr) 64 | 65 | 66 | @pytest.mark.parametrize("instr_cls", comm_start_instrs) 67 | def test_serialization_comm_start_diff_dim( 68 | instr_cls: CommunicationStartInstruction, 69 | ): 70 | buffer_ids = [0, 1] 71 | buffer_shapes = [(2, 2, 2), (2, 2)] 72 | instr = instr_cls( 73 | 0, 1, peer=0, buffer_shapes=buffer_shapes, buffer_ids=buffer_ids 74 | ) 75 | _test_serialization(instr) 76 | 77 | 78 | @pytest.mark.parametrize("instr_cls", comm_finish_instrs) 79 | @pytest.mark.parametrize("n_buffers", [1, 2, 3]) 80 | def test_serialization_comm_finish( 81 | instr_cls: CommunicationFinishInsturction, n_buffers: int 82 | ): 83 | instr = instr_cls(0, 1, peer=0, buffer_ids=list(range(n_buffers))) 84 | _test_serialization(instr) 85 | 86 | 87 | def test_serialization_forward(): 88 | instr = ForwardPass(0, 1, [0, 1]) 89 | _test_serialization(instr) 90 | 91 | 92 | @pytest.mark.parametrize("first_bw_layer", [True, False]) 93 | def test_serialization_backward(first_bw_layer: bool): 94 | instr = 
BackwardPass(0, 1, [0, 1], first_bw_layer=first_bw_layer) 95 | _test_serialization(instr) 96 | 97 | 98 | @pytest.mark.parametrize("instr_cls", buffer_instrs) 99 | @pytest.mark.parametrize("n_buffers", [1, 2, 3]) 100 | @pytest.mark.parametrize("buffer_dims", [0, 1, 2, 3]) 101 | def test_serialization_buffer( 102 | instr_cls: Union[Type[LoadInput], Type[FreeBuffer]], 103 | n_buffers: int, 104 | buffer_dims: int, 105 | ): 106 | buffer_ids = list(range(n_buffers)) 107 | buffer_shapes = [ 108 | tuple([2 for _ in range(buffer_dims)]) for _ in range(n_buffers) 109 | ] 110 | if instr_cls == LoadInput: 111 | instr = instr_cls( 112 | 0, 113 | 1, 114 | buffer_shapes=buffer_shapes, 115 | buffer_ids=buffer_ids, 116 | ) 117 | else: 118 | instr = instr_cls(buffer_ids=buffer_ids) 119 | _test_serialization(instr) 120 | 121 | 122 | def test_serialization_exec_plan(): 123 | instructions = [ 124 | LoadInput(0, 0, buffer_shapes=[(2, 2, 2), (2, 2)], buffer_ids=[0, 1]), 125 | SendActivationStart( 126 | 0, 1, peer=0, buffer_shapes=[(2, 2, 2), (2, 2)], buffer_ids=[0, 1] 127 | ), 128 | SendActivationFinish(0, 1, peer=0, buffer_ids=[0, 1]), 129 | RecvActivationStart( 130 | 0, 1, peer=0, buffer_shapes=[(2, 2, 2), (2, 2)], buffer_ids=[0, 1] 131 | ), 132 | RecvActivationFinish(0, 1, peer=0, buffer_ids=[0, 1]), 133 | ForwardPass(0, 1, buffer_ids=[0, 1]), 134 | SendGradStart( 135 | 0, 1, peer=0, buffer_shapes=[(2, 2, 2), (2, 2)], buffer_ids=[0, 1] 136 | ), 137 | SendGradFinish(0, 1, peer=0, buffer_ids=[0, 1]), 138 | RecvGradStart( 139 | 0, 1, peer=0, buffer_shapes=[(2, 2, 2), (2, 2)], buffer_ids=[0, 1] 140 | ), 141 | RecvGradFinish(0, 1, peer=0, buffer_ids=[0, 1]), 142 | BackwardPass(0, 1, buffer_ids=[0, 1], first_bw_layer=True), 143 | BackwardPass(0, 1, buffer_ids=[0, 1], first_bw_layer=False), 144 | FreeBuffer(buffer_ids=[0, 1]), 145 | ] 146 | exec_plan = ExecutionPlan( 147 | instructions, 148 | micro_batches=8, 149 | nranks=4, 150 | nstages=4, 151 | rank=1, 152 | assigned_stages=[0, 1], 153 | recompute_method=RecomputeMethod.SELECTIVE, 154 | num_pipe_buffers=2, 155 | ) 156 | _test_serialization(exec_plan) 157 | 158 | 159 | if __name__ == "__main__": 160 | pytest.main([__file__]) 161 | -------------------------------------------------------------------------------- /tests/test_kv_store.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # Note: this test requires torch 5 | # to run this test, exec: 6 | # DYNAPIPE_DEBUG=DEBUG DYNAPIPE_LOGGING_DEBUG_DIR=./test_debug \ 7 | # torchrun --standalone --nnodes=1 --nproc_per_node=1 test_kv_store.py 8 | 9 | import multiprocessing as mp 10 | 11 | from dynapipe.pipe.data_loader import ( 12 | _get_from_shared_kv_store, 13 | _init_kv_store, 14 | _put_to_shared_kv_store, 15 | ) 16 | 17 | 18 | def _producer_process(max_iters, buffer_size=32): 19 | try: 20 | kv_store, _, _ = _init_kv_store(is_master=True) 21 | # set all ack keys 22 | for i in range(buffer_size): 23 | kv_store.set(f"key_{i}_ack", "1") 24 | kv_store.set(f"key_{i}_r0_ack", "1") 25 | for i in range(max_iters): 26 | key = "key_{}".format(i % buffer_size) 27 | payload = str(i) 28 | _put_to_shared_kv_store(kv_store, key, payload) 29 | print("[producer] put key: {}".format(key), flush=True) 30 | import time 31 | 32 | time.sleep(2) 33 | except Exception as e: 34 | import traceback 35 | 36 | traceback.print_exc() 37 | raise e 38 | 39 | 40 | def _consumer_process(max_iters, buffer_size=32): 41 | try: 42 | kv_store, _, _ = _init_kv_store(is_master=False) 43 | for i in range(max_iters): 44 | key = "key_{}".format(i % buffer_size) 45 | payload = _get_from_shared_kv_store( 46 | kv_store, key, 0, 1, decode=True 47 | ) 48 | assert payload == str(i) 49 | print("[consumer] got key: {}".format(key), flush=True) 50 | except Exception as e: 51 | import traceback 52 | 53 | traceback.print_exc() 54 | raise e 55 | 56 | 57 | def test_kv_store(): 58 | max_iters = 1000 59 | buffer_size = 32 60 | producer = mp.Process( 61 | target=_producer_process, args=(max_iters, buffer_size) 62 | ) 63 | consumer = mp.Process( 64 | target=_consumer_process, args=(max_iters, buffer_size) 65 | ) 66 | producer.start() 67 | consumer.start() 68 | consumer.join() 69 | producer.join() 70 | 71 | 72 | if __name__ == "__main__": 73 | test_kv_store() 74 | -------------------------------------------------------------------------------- /tests/test_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import shutil 6 | import sys 7 | 8 | import pytest 9 | 10 | LOGGING_TEST_DIR = "./logger_test" 11 | CREATED_FILES = [ 12 | "test_warning.log", 13 | "test_stderr_warning.log", 14 | "test_multiline_stderr_warning.log", 15 | ] 16 | 17 | os.environ["DYNAPIPE_DEBUG"] = "DEBUG" 18 | os.environ["DYNAPIPE_LOGGING_DEBUG_DIR"] = LOGGING_TEST_DIR 19 | 20 | from dynapipe.utils.logger import create_logger # noqa: E402 21 | 22 | 23 | @pytest.fixture(scope="module", autouse=True) 24 | def prepare_and_cleanup(): 25 | if os.path.exists(LOGGING_TEST_DIR): 26 | for files_to_remove in CREATED_FILES: 27 | fn = os.path.join(LOGGING_TEST_DIR, files_to_remove) 28 | if os.path.exists(fn): 29 | os.remove(fn) 30 | yield 31 | if os.path.exists(LOGGING_TEST_DIR): 32 | shutil.rmtree(LOGGING_TEST_DIR) 33 | 34 | 35 | def _get_number_of_lines(string: str): 36 | return len(string.strip().split("\n")) 37 | 38 | 39 | def test_logger_warning(capfd): 40 | logger = create_logger( 41 | "test_logger", prefix="Warn Test", log_file="test_warning.log" 42 | ) 43 | logger.warning("This is a warning.") 44 | 45 | _, stderr_contents = capfd.readouterr() 46 | # stderr should be colored 47 | assert "\x1b[0m" in stderr_contents 48 | # stderr should contain one line 49 | assert _get_number_of_lines(stderr_contents) == 1 50 | fn = os.path.join(LOGGING_TEST_DIR, "test_warning.log") 51 | with open(fn, "r") as f: 52 | log_contents = f.read() 53 | # log file should not be colored 54 | assert "\x1b[0m" not in log_contents 55 | # log file should contain one line 56 | assert _get_number_of_lines(log_contents) == 1 57 | 58 | 59 | def test_stderr_warning(capfd): 60 | _ = create_logger( 61 | "test_logger", prefix="Warn Test", log_file="test_stderr_warning.log" 62 | ) 63 | print("This is a warning from stderr.", file=sys.stderr) 64 | _, stderr_contents = capfd.readouterr() 65 | # stderr should be colored 66 | assert "\x1b[0m" in stderr_contents 67 | # stderr should contain one line 68 | assert _get_number_of_lines(stderr_contents) == 1 69 | fn = os.path.join(LOGGING_TEST_DIR, "test_stderr_warning.log") 70 | with open(fn, "r") as f: 71 | log_contents = f.read() 72 | # log file should not be colored 73 | assert "\x1b[0m" not in log_contents 74 | # log file should contain one line 75 | assert _get_number_of_lines(log_contents) == 1 76 | 77 | 78 | def test_stderr_multiline(capfd): 79 | _ = create_logger( 80 | "test_logger", 81 | prefix="Warn Test", 82 | log_file="test_multiline_stderr_warning.log", 83 | ) 84 | print( 85 | "This is a warning from stderr.\nThis is the second line.", 86 | file=sys.stderr, 87 | ) 88 | _, stderr_contents = capfd.readouterr() 89 | # stderr should be colored 90 | assert "\x1b[0m" in stderr_contents 91 | # stderr should contain two lines 92 | assert _get_number_of_lines(stderr_contents) == 2 93 | fn = os.path.join(LOGGING_TEST_DIR, "test_multiline_stderr_warning.log") 94 | with open(fn, "r") as f: 95 | log_contents = f.read() 96 | # log file should not be colored 97 | assert "\x1b[0m" not in log_contents 98 | # log file should contain two lines 99 | assert _get_number_of_lines(log_contents) == 2 100 | 101 | 102 | if __name__ == "__main__": 103 | # test_stderr_multiline() 104 | pytest.main([__file__]) 105 | -------------------------------------------------------------------------------- /tests/test_memory_opt/test_cuda_memory_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates.
All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pickle 5 | from pprint import pprint 6 | 7 | import torch 8 | 9 | from dynapipe.memory_opt.cuda_caching_allocator import ( 10 | get_allocator, 11 | override_allocator, 12 | ) 13 | 14 | 15 | def test_cuda_stats(): 16 | override_allocator() 17 | allocator = get_allocator() 18 | 19 | a = torch.zeros((1024, 1024), device="cuda") 20 | del a 21 | b = torch.zeros((1024, 1024), device="cuda") # noqa: F841 22 | c = torch.zeros((64, 32), device="cuda") # noqa: F841 23 | 24 | pickled_snapshot = allocator.get_memory_snapshot() 25 | 26 | py_snapshot = pickle.loads(pickled_snapshot) 27 | pprint(py_snapshot["segments"]) 28 | 29 | print( 30 | "Peak reserved memory: {} MB".format( 31 | allocator.peak_reserved_cuda_memory() / 1e6 32 | ) 33 | ) 34 | print( 35 | "Peak allocated memory: {} MB".format( 36 | allocator.peak_allocated_cuda_memory() / 1e6 37 | ) 38 | ) 39 | print( 40 | "Peak requested memory: {} MB".format( 41 | allocator.peak_requested_cuda_memory() / 1e6 42 | ) 43 | ) 44 | print( 45 | "Current reserved memory: {} MB".format( 46 | allocator.current_reserved_cuda_memory() / 1e6 47 | ) 48 | ) 49 | print( 50 | "Current allocated memory: {} MB".format( 51 | allocator.current_allocated_cuda_memory() / 1e6 52 | ) 53 | ) 54 | print( 55 | "Current requested memory: {} MB".format( 56 | allocator.current_requested_cuda_memory() / 1e6 57 | ) 58 | ) 59 | del b 60 | del c 61 | 62 | # reset stats 63 | allocator.reset_peak_stats() 64 | print( 65 | "Peak reserved memory: {} MB".format( 66 | allocator.peak_reserved_cuda_memory() / 1e6 67 | ) 68 | ) 69 | print( 70 | "Peak allocated memory: {} MB".format( 71 | allocator.peak_allocated_cuda_memory() / 1e6 72 | ) 73 | ) 74 | print( 75 | "Peak requested memory: {} MB".format( 76 | allocator.peak_requested_cuda_memory() / 1e6 77 | ) 78 | ) 79 | 80 | 81 | if __name__ == "__main__": 82 | test_cuda_stats() 83 | -------------------------------------------------------------------------------- /tests/test_memory_opt/test_host_caching_allocator.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import random 5 | import time 6 | 7 | import torch 8 | from tqdm import trange 9 | 10 | from dynapipe.memory_opt.host_caching_allocator import HostCachingAllocator 11 | 12 | 13 | def test_host_caching_allocator(preallocate=False): 14 | hca = HostCachingAllocator() 15 | dtype_dict = { 16 | 1: torch.uint8, 17 | 2: torch.int8, 18 | 3: torch.float32, 19 | 4: torch.float16, 20 | } 21 | pinned_tensors = [] 22 | torch_tensors = [] 23 | random.seed(42) 24 | torch.manual_seed(0) 25 | total_malloc_time = 0 26 | if preallocate: 27 | x = hca.malloc((32 * 1024 * 1024 * 4 * 20,), torch.uint8) # noqa: F841 28 | del x 29 | for step in trange(1000): 30 | mbs = random.randint(1, 32) 31 | enc_seqlen = random.randint(1, 1024) 32 | dec_seqlen = random.randint(1, 1024) 33 | shape = (mbs, enc_seqlen, dec_seqlen) 34 | dtype = dtype_dict[random.randint(1, 4)] 35 | start = time.time() 36 | pinned_tensor = hca.malloc(shape, dtype) 37 | total_malloc_time += time.time() - start 38 | if dtype == torch.uint8 or dtype == torch.int8: 39 | torch_tensor = torch.randint( 40 | 0, 128, shape, dtype=dtype, device="cpu" 41 | ) 42 | else: 43 | torch_tensor = torch.rand(shape, dtype=dtype, device="cpu") 44 | pinned_tensor.copy_(torch_tensor) 45 | pinned_tensors.append(pinned_tensor) 46 | torch_tensors.append(torch_tensor) 47 | 48 | if step > 20: 49 | # free some tensors 50 | avail_idx = [ 51 | i for i, x in enumerate(pinned_tensors) if x is not None 52 | ] 53 | idx = random.choice(avail_idx) 54 | # check if the tensor is still the same 55 | assert torch.allclose(pinned_tensors[idx], torch_tensors[idx]) 56 | pinned_tensors[idx] = None 57 | torch_tensors[idx] = None 58 | print(f"total mallocs: {hca.n_backend_mallocs}") 59 | print(f"total malloc time: {total_malloc_time:.2f} seconds") 60 | 61 | 62 | if __name__ == "__main__": 63 | # test_host_caching_allocator(preallocate=False) 64 | test_host_caching_allocator(preallocate=True) 65 | -------------------------------------------------------------------------------- /tests/test_scheduler/test_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pickle 5 | 6 | import pytest 7 | 8 | from dynapipe.model import ( 9 | DynaPipeMicrobatch, 10 | DynaPipeMinibatch, 11 | get_uniform_cluster, 12 | get_uniform_microbatch, 13 | ) 14 | from dynapipe.schedule_opt.execution_planner import ( 15 | ExecutionPlan, 16 | optimize_schedule, 17 | ) 18 | 19 | 20 | def hetero_minibatch() -> DynaPipeMinibatch: 21 | fw_times = [4000] * 4 + [2000] * 4 # 8 layer ende 22 | # fw_times = [4000] * 4 + [4000] * 4 23 | microbatch_multiplier = [1, 0.8, 1.2, 0.9, 1.1, 0.7, 1.4, 0.6] 24 | memory_multiplier = [1, 0.8, 1.2, 0.9, 1.1, 0.7, 1.4, 0.6] 25 | microbatches = [] 26 | for i in range(len(microbatch_multiplier)): 27 | current_fw_times = [ 28 | fw_times[j] * microbatch_multiplier[i] 29 | for j in range(len(fw_times)) 30 | ] 31 | current_bw_times = [2 * t for t in current_fw_times] 32 | microbatch = DynaPipeMicrobatch(str(i)) 33 | microbatch.set_fw_exec_times(current_fw_times) 34 | microbatch.set_bw_exec_times(current_bw_times) 35 | microbatch.set_fw_comm_size( 36 | [2 * microbatch_multiplier[i]] * (len(fw_times) - 1) 37 | ) 38 | microbatch.set_bw_comm_size( 39 | [2 * microbatch_multiplier[i]] * (len(fw_times) - 1) 40 | ) 41 | microbatch.set_model_state_memory([4000] * len(fw_times)) 42 | microbatch.set_model_stored_activation_memory( 43 | [8000 * memory_multiplier[i]] * len(fw_times) 44 | ) 45 | microbatch.set_model_peak_activation_memory( 46 | [16000 * memory_multiplier[i]] * len(fw_times) 47 | ) 48 | microbatch.set_activation_shapes( 49 | [[(64, 128, 512)]] * (len(fw_times) // 2) 50 | + [[(64, 128, 512), (64, 128, 512)]] * (len(fw_times) // 2) 51 | ) 52 | microbatches.append(microbatch) 53 | minibatch = DynaPipeMinibatch("test", microbatches) 54 | return minibatch 55 | 56 | 57 | @pytest.mark.parametrize("use_het_batch", [False, True]) 58 | @pytest.mark.parametrize( 59 | "device_assignment, sch_type, expected_file", 60 | [ 61 | ([0, 0, 1, 1, 2, 2, 3, 3], "1F1B", "1f1b"), 62 | ([0, 0, 1, 1, 2, 2, 3, 3], "wait-free-cyclic", "wfcyclic"), 63 | ([0, 1, 2, 3, 0, 1, 2, 3], "interleaved-1F1B", "interleaved_1f1b"), 64 | ([0, 1, 2, 3, 0, 1, 2, 3], "wait-free-cyclic", "interleaved_wfcyclic"), 65 | ([0, 1, 2, 3, 3, 2, 1, 0], "wait-free-cyclic", "zigzag_wfcyclic"), 66 | ], 67 | ) 68 | def test_minibatch( 69 | use_het_batch, 70 | device_assignment, 71 | sch_type, 72 | expected_file, 73 | try_permutations=False, 74 | memory_limit=float("inf"), 75 | ): 76 | cluster = get_uniform_cluster(4) 77 | if not use_het_batch: 78 | microbatches = [] 79 | for i in range(8): 80 | microbatch = get_uniform_microbatch(8, comm_ratio=0.1) 81 | microbatch.name = str(i) 82 | microbatches.append(microbatch) 83 | minibatch = DynaPipeMinibatch("test", microbatches) 84 | else: 85 | minibatch = hetero_minibatch() 86 | ( 87 | _, 88 | _, 89 | _, 90 | min_makespan, 91 | min_stats, 92 | min_instructions, 93 | ) = optimize_schedule( 94 | sch_type, 95 | minibatch, 96 | cluster, 97 | device_assignment, 98 | try_permutations=try_permutations, 99 | include_memory_stats=True, 100 | progress_bar=False, 101 | memory_limit=memory_limit, 102 | ) 103 | n_stages = ( 104 | max([instr.stage for instrs in min_instructions for instr in instrs]) 105 | + 1 106 | ) 107 | eps = [ 108 | ExecutionPlan(instrs, 8, 4, n_stages, i, [0, 1]) 109 | for i, instrs in enumerate(min_instructions) 110 | ] 111 | serialized_eps = [ep.serialize() for ep in eps] 112 | prefix = "uniform" if not use_het_batch else "heter" 113 | expected_result_path = "{}_{}.pkl".format(prefix, expected_file) 
114 | # Uncomment to get generated trace 115 | # trace_path = "{}_{}_trace.json".format(prefix, expected_file) 116 | # with open(trace_path, "w") as f: 117 | # import json 118 | # json.dump(min_stats[-1], f) 119 | # Uncomment to generate new expected results 120 | # with open(expected_result_path, "wb") as f: 121 | # pickle.dump((min_makespan, min_stats[1], serialized_eps), f) 122 | with open(expected_result_path, "rb") as f: 123 | expected_result = pickle.load(f) 124 | assert expected_result == (min_makespan, min_stats[1], serialized_eps) 125 | 126 | 127 | if __name__ == "__main__": 128 | pytest.main([__file__]) 129 | --------------------------------------------------------------------------------