├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── ml_goodput_measurement ├── __init__.py ├── src │ ├── checkpoint_badput_calculator.py │ ├── gcp_metrics.py │ ├── goodput.py │ ├── goodput_cache.py │ ├── goodput_utils.py │ └── monitoring.py └── tests │ ├── checkpoint_badput_calculator_test.py │ ├── gcp_metrics_test.py │ ├── goodput_cache_test.py │ ├── goodput_test.py │ └── monitoring_test.py └── pyproject.toml /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | 23 | ## [0.0.11] - 2025-06-06 24 | 25 | * Support for monitoring performance degradations. 26 | * Support for monitoring rolling window Goodput. 27 | * Force upload of final metrics and safe exit. 28 | 29 | ## [0.0.10] - 2025-04-28 30 | 31 | * Support for custom badput events which are synchronous and training-overlapped. 32 | * Handling of edge case caching scenario. 33 | 34 | ## [0.0.9] - SKIPPED 35 | 36 | * Used for external testing. Please upgrade to 0.0.10. 37 | 38 | ## [0.0.8] - 2025-04-03 39 | 40 | * Fix computation of ideal step time when step_times is empty. 41 | 42 | ## [0.0.7] - 2025-03-24 43 | 44 | * Cache updates to Other/Unknown Badput. 45 | * Exclude monitoring asynchronous Badput types in GCM. 46 | * Total and last step updates with hidden events. 47 | * Interval Query Monitoring in GCM. 48 | 49 | ## [0.0.6] - 2025-03-17 50 | 51 | * Updates to data loading Badput buckets (Separated into Async & Sync). 52 | * Short term fix to Pathways SuspendResume anomalous step time detection. 53 | * Updates to account for Pathways Elastic Training. 54 | * Automatic asynchronous upload of goodput, badput and step time deviation metrics to GCM. 55 | 56 | ## [0.0.5] - 2025-02-03 57 | 58 | * Goodput Cache and library improvements. 59 | * Query and Monitor API support for checkpoint save and restore. 60 | * Interval Query API support. 61 | * Query and Monitor API support for step time deviation. 
62 | 63 | ## [0.0.4] - 2024-09-13 64 | 65 | * Add Badput breakdown to GoodputMonitor. 66 | * Add Checkpoint Badput Calculator backend. 67 | * Return last recorded step from Goodput query API. 68 | * Bug Fixes 69 | * Fix a potential race-condition with Tensorboard write to GCS. 70 | * Fix zero job time issue on long running jobs 71 | 72 | ## [0.0.3] - 2024-05-28 73 | 74 | * Compute and discount Badput from first step after start or restart. 75 | * Compute and discount Badput due to anomalous step times (Pathways only). 76 | * Badput recording APIs 77 | * Some Badput computation APIs (TPU initialization , training preparation, data loading, program startup) 78 | * Goodput monitoring API to asynchronously query and upload Goodput to Tensorboard. 79 | * Bug Fixes 80 | * Fix Goodput calculation with disruptions 81 | * Fix some Cloud Logging latency and batching issues. 82 | 83 | ## [0.0.2] - 2024-02-29 84 | 85 | * Bug Fixes 86 | * Fixes a typing mismatch in total step time calculation. 87 | * Code and documentation cleanup 88 | 89 | ## [0.0.1] - 2024-02-26 90 | 91 | * Initial release of ML Goodput Measurement PyPi package 92 | * Feature: Contains the Goodput module which allows logging and retrieval of training job's overall productive Goodput 93 | 94 | [0.0.11]: https://github.com/AI-Hypercomputer/ml-goodput-measurement/compare/v0.0.10...v0.0.11 95 | [0.0.10]: https://github.com/AI-Hypercomputer/ml-goodput-measurement/compare/v0.0.8...v0.0.10 96 | [0.0.8]: https://github.com/AI-Hypercomputer/ml-goodput-measurement/compare/v0.0.7...v0.0.8 97 | [0.0.7]: https://github.com/AI-Hypercomputer/ml-goodput-measurement/compare/v0.0.6...v0.0.7 98 | [0.0.6]: https://github.com/AI-Hypercomputer/ml-goodput-measurement/compare/v0.0.5...v0.0.6 99 | [0.0.5]: https://github.com/AI-Hypercomputer/ml-goodput-measurement/compare/v0.0.4...v0.0.5 100 | [0.0.4]: https://github.com/AI-Hypercomputer/ml-goodput-measurement/compare/v0.0.3...v0.0.4 101 | [0.0.3]: 
https://github.com/AI-Hypercomputer/ml-goodput-measurement/compare/v0.0.2...v0.0.3 102 | [0.0.2]: https://github.com/AI-Hypercomputer/ml-goodput-measurement/compare/v0.0.1...v0.0.2 103 | [0.0.1]: https://github.com/AI-Hypercomputer/ml-goodput-measurement/releases/tag/v0.0.1 -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We'd love to accept your patches and contributions to this project. 4 | 5 | ## Before you begin 6 | 7 | ### Sign our Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a 10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 11 | You (or your employer) retain the copyright to your contribution; this simply 12 | gives us permission to use and redistribute your contributions as part of the 13 | project. 14 | 15 | If you or your current employer have already signed the Google CLA (even if it 16 | was for a different project), you probably don't need to do it again. 17 | 18 | Visit to see your current agreements or to 19 | sign a new one. 20 | 21 | ### Review our community guidelines 22 | 23 | This project follows 24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 25 | 26 | ## Contribution process 27 | 28 | ### Code reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use GitHub pull requests for this purpose. Consult 32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 33 | information on using pull requests. 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 16 | # ML Goodput Measurement 17 | 18 | ## Overview 19 | 20 | ML Goodput Measurement is a library intended to be used with Cloud accelerators 21 | to log necessary information and query a job's Goodput and Badput Breakdown. It 22 | can be pip installed to import its modules, and retrieve information about a 23 | training job's overall productive Goodput and sources of Badput. The package 24 | exposes API interfaces to log useful information from the user application and 25 | query Goodput for the job run, gain insight into the productivity of ML 26 | workloads and utilization of compute resources. 27 | 28 | The package also exposes Goodput Monitoring APIs which allow asynchronous query 29 | and export of the job's Goodput, Badput and Step Time Deviation to Tensorboard 30 | with configurable upload interval. 
31 | 32 | ## Components 33 | 34 | 35 | The ML Goodput Measurement library consists of the following main components: 36 | 37 | - `GoodputRecorder` 38 | 39 | - `GoodputCalculator` 40 | - `GoodputMonitor` 41 | - `GoodputCache` 42 | 43 | 44 | The `GoodputRecorder` 45 | exposes APIs to the client to export key timestamps while a training job makes 46 | progress, namely APIs that allow logging of productive step time and total job 47 | run time. The library will serialize and store this data in Google Cloud 48 | Logging. 49 | 50 | The `GoodputCalculator` exposes APIs to compute Goodput based on the 51 | recorded data. Cloud Logging handles its internal operations asynchronously. 52 | The recommended way to compute Goodput is to run an analysis program separate 53 | from the training application, either on a CPU instance or on the users' 54 | development machine. 55 | 56 | Under the hood, the `GoodputCalculator` uses a `GoodputCache` which is an 57 | internal component that locally caches pre-computations and useful logs such 58 | that repeated computations can be made inexpensive. 59 | 60 | The `GoodputMonitor` exposes APIs to query and upload goodput and step time 61 | deviation data to Tensorboard asynchronously. It does this by instantiating a 62 | `GoodputCaluclator` under the hood. 63 | 64 | ## Installation 65 | 66 | To install the ML Goodput Measurement package, run the following command on the 67 | VM or machine you want to query or monitor your workload from: 68 | 69 | ```bash 70 | pip install ml-goodput-measurement 71 | ``` 72 | 73 | ## Usage 74 | 75 | The usage of this package requires the setup of a Google Cloud project with 76 | billing enabled to properly use Google Cloud Logging. If you don't have a Google 77 | Cloud project, or if you don't have billing enabled for your Google Cloud 78 | project, then do the following: 79 | 80 | 1. 
In the Google Cloud console, on the project selector page, 81 | [select or create a Google Cloud project](https://cloud.google.com/resource-manager/docs/creating-managing-projects). 82 | 83 | 2. Make sure that billing is enabled for your Google Cloud project. Instructions can be found [here](https://cloud.google.com/billing/docs/how-to/verify-billing-enabled#console) 84 | 85 | 3. [Enable](https://console.cloud.google.com/flows/enableapi?apiid=logging.googleapis.com&_ga=2.27841276.1571868865.1726250448-123998259.1726107009) the Cloud Logging API. 86 | 87 | To run your training on Cloud accelerator, set up the environment by following 88 | instructions [here](https://cloud.google.com/tpu/docs/setup-gcp-account). 89 | 90 | To learn more about Google Cloud Logging, visit this [page](https://cloud.google.com/logging/docs). 91 | 92 | ### Access Scopes 93 | 94 | You will need both read and write access scopes for cloud logging on both the 95 | GPU or TPU and CPU node pools. Full cloud logging access is granted by the 96 | following access scope during node pool creation: 97 | 98 | - `https://www.googleapis.com/auth/cloud-platform` 99 | 100 | XPK adds this access scope to the GPU, TPU and CPU node pools, so XPK is the 101 | recommended method to create clusters and node-pools in you intend to run 102 | your workloads on GKE. 103 | 104 | Instructions on how to create clusters using XPK can be 105 | found [here](https://github.com/AI-Hypercomputer/xpk/blob/main/README.md#cluster-create) 106 | and how to create workloads using XPK can be found 107 | [here](https://github.com/AI-Hypercomputer/xpk/blob/main/README.md#workload-create). 108 | 109 | > **_NOTE:_** Access Scopes are immutable and workloads can only be migrated 110 | to new node pools with required access scopes. Access scopes on already created 111 | clusters cannot be updated. 
112 | 113 | ### Import 114 | 115 | To use this package, import the `goodput` module: 116 | 117 | ```python 118 | from ml_goodput_measurement import goodput 119 | from ml_goodput_measurement import monitoring 120 | ``` 121 | 122 | ### Define the name of the Google Cloud Logging logger. 123 | 124 | Create a run-specific logger name where Cloud Logging entries can be written to 125 | and read from. 126 | 127 | > **IMPORTANT:** Please use a unique `run_name` for each individual experiment 128 | or workload that you intend to monitor separately. If you unintentionally re-use 129 | `run_name` or `goodput_logger_name` in the same storage bucket of a GCP project, 130 | your cumulative Goodput metrics may be inaccurately taking previous runs into 131 | account. 132 | 133 | For example: 134 | 135 | ```python 136 | goodput_logger_name = f'goodput_{config.run_name}' # Here run_name is unique. 137 | ``` 138 | 139 | ### Create a `GoodputRecorder` object 140 | 141 | Next, create a recorder object with the following parameters: 142 | 143 | 1. `job_name`: The full run name of the job. 144 | 2. `logger_name`: The name of the Cloud Logging logger object (created in the previous step). 145 | 3. `logging_enabled`: Whether or not this process has Cloud Logging enabled. 146 | 147 | > **_NOTE:_** For a multi-worker setup, please ensure that only one worker 148 | writes the logs to avoid the duplication. In JAX, for example, the check 149 | could be `if jax.process_index() == 0` 150 | 151 | > **_NOTE:_** `logging_enabled` defaults to `False` and Goodput computations 152 | cannot be completed if no logs are ever written. 153 | 154 | For example: 155 | 156 | ```python 157 | goodput_recorder = goodput.GoodputRecorder( 158 | job_name=config.run_name, 159 | logger_name=goodput_logger_name, 160 | logging_enabled=(jax.process_index() == 0) 161 | ) 162 | ``` 163 | 164 | > **_NOTE:_** JAX initialization should be complete before this call. 
165 | 166 | ### Record Data with `GoodputRecorder` 167 | 168 | #### Record Job Start and End Time 169 | 170 | Use the recorder object to record the job's overall start and end time. 171 | 172 | For example: 173 | 174 | ```python 175 | def main(argv: Sequence[str]) -> None: 176 | # Initialize configs… 177 | goodput_recorder.record_job_start_time(datetime.datetime.now()) 178 | # Device Initialization and device scanning… 179 | # Set up other things for the main training loop… 180 | # Main training loop 181 | train_loop(config) 182 | goodput_recorder.record_job_end_time(datetime.datetime.now()) 183 | ``` 184 | 185 | #### Record Step Time 186 | 187 | Use the recorder object to record a step's start time using 188 | `record_step_start_time(step_count)`: 189 | 190 | For example: 191 | 192 | ```python 193 | def train_loop(config, state=None): 194 | # Set up mesh, model, state, checkpoint manager… 195 | 196 | # Initialize functional train arguments and model parameters… 197 | 198 | # Define the compilation 199 | 200 | for step in np.arange(start_step, config.steps): 201 | goodput_recorder.record_step_start_time(step) 202 | # Training step… 203 | 204 | return state 205 | ``` 206 | 207 | #### Record Device Initialization, Training Preparation and Data Loading Time 208 | 209 | - Use the recorder object to record Device Initialization time using 210 | `record_tpu_init_start_time` and `record_tpu_init_end_time`. 211 | - Use the recorder object to record Training Preparation time using 212 | `record_training_preparation_start_time` and 213 | `record_training_preparation_end_time`. 214 | - Use the recorder object to record Data Loading time using 215 | `record_data_loading_start_time` and `record_data_loading_end_time`. 
216 | 217 | For example: 218 | 219 | ```python 220 | def train_loop(config, state=None): 221 | goodput_recorder.record_tpu_init_start_time() 222 | # Set up mesh, model, state, checkpoint manager… 223 | goodput_recorder.record_tpu_init_end_time() 224 | goodput_recorder.record_training_preparation_start_time() 225 | # Set up training set, initialize functional train args and model parameters… 226 | # Define the compilation 227 | # Set up any metrics collectors 228 | goodput_recorder.record_training_preparation_end_time() 229 | 230 | for step in np.arange(start_step, config.steps): 231 | goodput_recorder.record_data_loading_start_time() 232 | example_batch = load_next_batch(data_iterator, example_batch, config) 233 | goodput_recorder.record_data_loading_end_time() 234 | goodput_recorder.record_step_start_time(step) 235 | # Training step… 236 | 237 | return state 238 | ``` 239 | 240 | #### Record Custom Badput Events (e.g., Evaluation, SDC Checks) 241 | 242 | - Use the recorder object to record the **start** of a custom badput event using 243 | `record_custom_badput_event_start_time(custom_badput_event_type='your_event_name')`. 244 | - Use the recorder object to record the **end** of a custom badput event using 245 | `record_custom_badput_event_end_time(custom_badput_event_type='your_event_name')`. 246 | 247 | Use these APIs when you want to account for time spent on operations that 248 | block the training loop and use accelerator resources, do not contribute to 249 | productive training and occur while training is in progress — such as step 250 | evaluations, SDC checks, or re-compilations. 251 | 252 | For example: 253 | 254 | ```python 255 | def train_loop(config, state=None): 256 | goodput_recorder.record_training_preparation_start_time() 257 | # Initialize training config, setup model, load checkpoint... 
258 | goodput_recorder.record_training_preparation_end_time() 259 | 260 | for step in range(config.steps): 261 | goodput_recorder.record_data_loading_start_time() 262 | batch = load_batch(train_data) 263 | goodput_recorder.record_data_loading_end_time() 264 | 265 | goodput_recorder.record_step_start_time(step) 266 | # Run training step... 267 | run_train_step(step, state) 268 | 269 | if step % config.eval_interval == 0: 270 | # Record a custom badput event for evaluation 271 | goodput_recorder.record_custom_badput_event_start_time( 272 | custom_badput_event_type="eval_step") 273 | run_step_evaluation(model, val_data) 274 | goodput_recorder.record_custom_badput_event_end_time( 275 | custom_badput_event_type="eval_step") 276 | 277 | if step % config.sdc_check_interval == 0: 278 | # Record a custom badput event for SDC check 279 | goodput_recorder.record_custom_badput_event_start_time( 280 | custom_badput_event_type="sdc_check") 281 | run_sdc_check(state) 282 | goodput_recorder.record_custom_badput_event_end_time( 283 | custom_badput_event_type="sdc_check") 284 | 285 | return state 286 | ``` 287 | 288 | > **_NOTE:_** The `custom_badput_event_type` string should be descriptive and 289 | consistent (e.g., "eval_step", "sdc_check"), to ensure accurate aggregation and 290 | reporting in badput breakdowns. 291 | 292 | ### Retrieve Goodput with `GoodputCalculator` 293 | 294 | In order to retrieve the Goodput of a job run, all you need to do is instantiate 295 | a `GoodputCalculator` object with the job's run name and the Cloud Logging 296 | logger name used to record data for that job run. Then call the 297 | `get_job_goodput` API to get the computed Goodput for the job run. 298 | 299 | It is recommended to make the `get_job_goodput` calls for a job run from an 300 | instance that runs elsewhere from your training machine. 
301 | 302 | #### Create a `GoodputCalculator` object 303 | 304 | Create the calculator object: 305 | 306 | ```python 307 | goodput_logger_name = f'goodput_{config.run_name}' # You can choose your own logger name. 308 | goodput_calculator = goodput.GoodputCalculator(job_name=config.run_name, logger_name=goodput_logger_name) 309 | ``` 310 | 311 | If you want to enable Pathways, turn on the `using_pathways` flag: 312 | 313 | ```python 314 | goodput_logger_name = f'goodput_{config.run_name}' # You can choose your own logger name. 315 | goodput_calculator = goodput.GoodputCalculator(job_name=config.run_name, logger_name=goodput_logger_name, using_pathways=True) 316 | ``` 317 | 318 | #### Retrieve Goodput 319 | 320 | Finally, call the `get_job_goodput` API to retrieve Goodput for the entire job 321 | run. This API takes an optional parameter `include_badput_breakdown`. which 322 | defaults to `False`. 323 | 324 | The returned result is a tuple of the job’s Goodput at query-time, a dictionary 325 | mapping various sources of Badput and their corresponding percentages and the 326 | last recorded step. If `include_badput_breakdown` is not set, an empty 327 | dictionary for Badput is returned. 328 | 329 | If you are only interested in Goodput: 330 | 331 | ```python 332 | total_goodput, _, _ = goodput_calculator.get_job_goodput() 333 | print(f"Total job goodput: {total_goodput:.2f}%") 334 | ``` 335 | 336 | #### Retrieve Badput Breakdown 337 | 338 | Badput breakdown is dictionary representation of various sources of Badput 339 | mapped to its corresponding value. Badput is the percentage of time spent by the 340 | job doing work that is not training to the total lifetime of the job. This 341 | includes time spent doing device initialization, training preparation, 342 | program startup, checkpoint loading, compilation or re-compilation, data loading, 343 | checkpoint saving, custom badput events, wasted progress and time lost due 344 | to disruptions. 
345 | 346 | Following Badput Breakdown buckets are supported by the library at this time: 347 | 348 | ```python 349 | # Supported Badput Types 350 | class BadputType(enum.Enum): 351 | """The type of Badput.""" 352 | 353 | TPU_INITIALIZATION = 1 354 | TRAINING_PREP = 2 355 | PROGRAM_STARTUP = 3 356 | DATA_LOADING_SYNC = 4 357 | DATA_LOADING_ASYNC = 5 # This does not affect Goodput 358 | UNPRODUCTIVE_CHECKPOINT_SAVE_TIME = 6 359 | UNPRODUCTIVE_CHECKPOINT_RESTORE_TIME = 7 360 | WASTED_PROGRESS_FROM_DISRUPTION = 8 361 | CUSTOM_BADPUT_EVENTS = 9 362 | OTHER = 10 363 | ``` 364 | 365 | #### Badput Breakdown Details 366 | 367 | - Accelerator Initialization Time (TPU_INITIALIZATION) 368 | 369 | This is the time spent on device discovery, slice initialization, 370 | device driver re-initialization and reset, security setup, initialization of 371 | pre-mapped buffers and more. 372 | 373 | - Training Preparation Time (TRAINING_PREP) 374 | 375 | This is the time spent on the creation of checkpoint managers, checkpoint 376 | loading, running mesh and model optimizers and more. 377 | 378 | - Program Startup Time (PROGRAM_STARTUP) 379 | 380 | This is the time spent on framework specific function transformations 381 | (such as JAX tracing), compilation tasks, runtime initialization etc. 382 | 383 | - Data Loading Time (DATA_LOADING_SYNC) 384 | 385 | This is the time spent on loading each batch of data for the training at a 386 | step to continue. This should be a small contribution to Badput if parallel 387 | data loading is used. 388 | 389 | Async data loading is accumulated overlapping with training steps and is 390 | non-blocking, therefore is not unproductive time. The time spent on overlapped 391 | data loading is stored in BadputType.DATA_LOADING_ASYNC, but does **not** 392 | affect overall Goodput of the workload. 
393 | 394 | - Checkpointing Time (UNPRODUCTIVE_CHECKPOINT_SAVE_TIME, UNPRODUCTIVE_CHECKPOINT_RESTORE_TIME) 395 | 396 | This is the time spent on saving a checkpoint and restoring a checkpoint. 397 | 398 | Depending on the type of checkpointing technology used by the program, there 399 | could be unproductive time while saving a checkpoint. When checkpointing is 400 | synchronous, the save operation will block training progress until it is complete. 401 | 402 | During asynchronous checkpointing, the model parameters or weights have to be 403 | transferred from the device memory to the host memory which is a blocking 404 | operation on the device. After the transfer, the device can proceed with model 405 | training while the CPU saves the checkpoint to storage in the background. The 406 | first blocking operation contributes to unproductive checkpoint time. 407 | 408 | If auto checkpointing is used, the checkpoint save operation is initiated upon 409 | detection of a planned disruption signal. The save operation in type of 410 | checkpointing is synchronous resulting in time lost to Badput. 411 | 412 | - Wasted Progress due to Disruption (WASTED_PROGRESS_FROM_DISRUPTION) 413 | 414 | Based on checkpointing frequency, a disruption may result in time lost in the 415 | form of wasted progress, i.e. time that was spent on productive training but 416 | lost after restart as well as time lost for the infrastructure to restart the 417 | workload. 418 | 419 | When there is a disruption, Badput is expected to accumulate in 420 | each of the following buckets after restart: 421 | 422 | - Accelerator Initialization 423 | - Training Preparation 424 | - Program Startup 425 | - Wasted Progress due to Disruption 426 | 427 | - Custom Badput Events (CUSTOM_BADPUT_EVENTS) 428 | 429 | Your application can optionally use record and monitor badput from custom 430 | synchronous (blocking training) and overlapping (between training steps) 431 | events. 
These events are generally used for useful non-training activity on 432 | the accelerator while training is in progress such as performing SDC checks 433 | or evaluations. 434 | 435 | If you are interested in retrieving Badput Breakdown along with Goodput: 436 | 437 | ```python 438 | current_goodput, badput_breakdown, last_step = goodput_calculator.get_job_goodput(include_badput_breakdown=True) 439 | print(f"Last step recorded: {last_step}") 440 | print(f"Goodput: {current_goodput:.2f}%") 441 | print(f"Badput due to TPU initialization: {badput_breakdown[goodput.BadputType.TPU_INITIALIZATION]:.2f}%") 442 | print(f"Badput due to training preparation: {badput_breakdown[goodput.BadputType.TRAINING_PREP]:.2f}%") 443 | print(f"Badput due to program startup: {badput_breakdown[goodput.BadputType.PROGRAM_STARTUP]:.2f}%") 444 | print(f"Badput due to data loading: {badput_breakdown[goodput.BadputType.DATA_LOADING_SYNC]:.2f}%") 445 | print(f"Badput due to disruption and wasted progress: {badput_breakdown[goodput.BadputType.WASTED_PROGRESS_FROM_DISRUPTION]:.2f}%") 446 | print(f"Badput due to checkpoint save: {badput_breakdown[goodput.BadputType.UNPRODUCTIVE_CHECKPOINT_SAVE_TIME]:.2f}%") 447 | print(f"Badput due to checkpoint restore: {badput_breakdown[goodput.BadputType.UNPRODUCTIVE_CHECKPOINT_RESTORE_TIME]:.2f}%") 448 | print(f"Badput due to step evaluation: {badput_breakdown[goodput.BadputType.CUSTOM_BADPUT_EVENTS].get('EVAL_STEP', 0.0):.2f}%") 449 | print(f"Badput due to SDC checks: {badput_breakdown[goodput.BadputType.CUSTOM_BADPUT_EVENTS].get('SDC_CHECK', 0.0):.2f}%") 450 | print(f"Badput from unknown source: {badput_breakdown[goodput.BadputType.OTHER]:.2f}%") 451 | ``` 452 | 453 | #### Interval Query Goodput and Badput 454 | 455 | If you are interested in retrieving Goodput and Badput of the workload within a 456 | specific window of time, the `GoodputCalculator` exposes the 457 | `get_job_goodput_interval` API which computes metrics between the start and end 458 | of this window. 
459 | 460 | This API also returns the last step recorded for the job, the total job time in 461 | this window and the number of disruptions within the interval window. 462 | 463 | > **_IMPORTANT:_** **Use this API if** you know the exact window of time within 464 | the workload's total run time that you are interested in. 465 | 466 | > **_IMPORTANT:_** **Do NOT use this API if** your workload has been manually 467 | disrupted. 468 | 469 | > **_IMPORTANT:_** **Do NOT use this API if** you have accidentally re-used a 470 | previous `run_name`. 471 | 472 | ```python 473 | # Example usage 474 | start_time_str = "2024-12-16 1:05:00" 475 | start_time_utc = convert_pst_to_utc(start_time_str) 476 | end_time_str = "2024-12-17 2:00:00" 477 | end_time_utc = convert_pst_to_utc(end_time_str) 478 | current_goodput, badput_breakdown, last_step, total_time, disruptions = goodput_calculator.get_job_goodput_interval(start_time_utc, end_time_utc) 479 | ``` 480 | 481 | ### Monitor Goodput with `GoodputMonitor` 482 | 483 | In order to monitor the Goodput of a job run on Tensorboard, all you need to do 484 | is instantiate a `GoodputMonitor` object with the job's run name, cloud logger 485 | name and Goodput monitoring configurations (as described below). Then call the 486 | `start_goodput_uploader` API to asynchronously query and upload measured Goodput 487 | to the specified Tensorboard directory. 488 | 489 | #### Create a `GoodputMonitor` object 490 | 491 | Create a `GoodputMonitor` object with the following parameters: 492 | 493 | 1. `job_name`: The full run name of the job. 494 | 2. `logger_name`: The name of the Cloud Logging logger object (created in the previous step). 495 | 3. `tensorboard_dir`: The directory to write TensorBoard data to. 496 | 4. `upload_interval`: The time interval at which to query and upload data to TensorBoard. 497 | 5. `monitoring_enabled`: Whether or not monitoring is enabled. 
498 | If the application is interested in monitoring Goodput, it should set 499 | this value to True. Only one worker should enable monitoring. 500 | 6. `include_badput_breakdown`: Whether to query and upload badput breakdown 501 | data to Tensorboard. 502 | 503 | > **_NOTE:_** Please ensure that only **one** worker enables monitoring of Goodput. 504 | In JAX, for example, the check could be `if jax.process_index() == 0` 505 | 506 | For example: 507 | 508 | ```python 509 | goodput_logger_name = f'goodput_{config.run_name}' # You can choose your own logger name. 510 | goodput_monitoring_enabled = config.monitor_goodput and jax.process_index() == 0 # Check the configs for whether or not to enable monitoring. 511 | 512 | goodput_monitor = monitoring.GoodputMonitor( 513 | job_name=config.run_name, 514 | logger_name=logger_name, 515 | tensorboard_dir=config.tensorboard_dir, 516 | upload_interval=config.goodput_upload_interval_seconds, 517 | monitoring_enabled=True, 518 | include_badput_breakdown=True, 519 | ) 520 | ``` 521 | 522 | If you want to enable Pathways, turn on the `pathway_enabled` flag: 523 | 524 | ```python 525 | goodput_logger_name = f'goodput_{config.run_name}' # You can choose your own logger name. 526 | goodput_monitoring_enabled = config.monitor_goodput and jax.process_index() == 0 # Check the configs for whether or not to enable monitoring. 
527 | 528 | goodput_monitor = monitoring.GoodputMonitor( 529 | job_name=config.run_name, 530 | logger_name=logger_name, 531 | tensorboard_dir=config.tensorboard_dir, 532 | upload_interval=config.goodput_upload_interval_seconds, 533 | monitoring_enabled=True, 534 | include_badput_breakdown=True, 535 | pathway_enabled=True 536 | ) 537 | ``` 538 | 539 | ### Monitor Cumulative Goodput Metrics 540 | 541 | #### Start asynchronous "query and upload" of Goodput 542 | 543 | Call the `start_goodput_uploader` API to spin off a thread which continuously 544 | queries and uploads cumulative Goodput metrics to Tensorboard & Google Cloud 545 | Monitoring. 546 | 547 | > **_NOTE:_** This will upload computed metrics to Google Cloud Monitoring 548 | by default. 549 | 550 | Following metrics are uploaded: 551 | 552 | - Productive Time (Goodput) 553 | - Unproductive Time (Badput Breakdown) 554 | - Total Elapsed Time 555 | - Maximum Productive Step Count 556 | - Disruptions Count 557 | - Step Time Deviation 558 | - Ideal Step Time 559 | 560 | ```python 561 | goodput_monitor.start_goodput_uploader() 562 | ``` 563 | 564 | #### Stop the Goodput Uploader 565 | 566 | Call the `stop_goodput_uploader` API to perform a final upload of all metrics 567 | and safely exit. 568 | 569 | > **_NOTE:_** This will stop all cumulative metrics upload threads. 570 | 571 | ```python 572 | goodput_monitor.stop_goodput_uploader() 573 | ``` 574 | 575 | ### Monitor Rolling Window Goodput Metrics 576 | 577 | #### Start asynchronous "query and upload" of Rolling Window Goodput 578 | 579 | Call the `start_rolling_window_goodput_uploader` API to start a background 580 | thread that continuously queries and uploads **rolling window goodput metrics** 581 | to Google Cloud Monitoring. 582 | 583 | You must provide a list of window durations in seconds (e.g., `[60, 300, 900]` 584 | for 1 min, 5 min, and 15 min windows). 
585 | 586 | Following metrics are uploaded: 587 | 588 | - Rolling Window Goodput 589 | - Rolling Window Badput Breakdown 590 | 591 | ```python 592 | goodput_monitor.start_rolling_window_goodput_uploader(rolling_windows_seconds=[60, 300, 900]) 593 | ``` 594 | 595 | #### Stop the Rolling Window Goodput Uploader 596 | 597 | Call the `stop_goodput_rolling_window_uploader` API to perform a final upload 598 | of rolling window metrics and safely shut down the background uploader thread. 599 | 600 | > **_NOTE:_** This will stop all rolling window metrics upload threads. 601 | 602 | ```python 603 | goodput_monitor.stop_goodput_rolling_window_uploader() 604 | ``` 605 | 606 | #### Visualize on Tensorboard 607 | 608 | 1. Make sure you have `tensorboard-plugin-profile`, `tensorflow` and `tensorboard` packages installed 609 | 2. Follow instructions [here](https://cloud.google.com/tpu/docs/profile-tpu-vm#start_profiling_the_model_training) to start the Tensorboard server 610 | 611 | #### Access Metrics on Google Cloud Monitoring 612 | 613 | By default, performance data is automatically sent to Google Cloud Monitoring, 614 | enabling visualization and alerting on dashboards. This includes both cumulative 615 | and rolling window metrics. 
616 | 617 | The metrics currently sent to Google Cloud Monitoring are: 618 | 619 | - **Cumulative Goodput:** 620 | [workload/goodput_time](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/goodput_time) 621 | - **Cumulative Badput:** 622 | [workload/badput_time](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/badput_time) 623 | - **Rolling Window Goodput:** 624 | [workload/interval_goodput](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/interval_goodput) 625 | - **Rolling Window Badput:** 626 | [workload/interval_badput](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/interval_badput) 627 | - **Total Elapsed Time:** 628 | [workload/total_elapsed_time](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/total_elapsed_time) 629 | - **Maximum Productive Step:** 630 | [workload/max_productive_steps](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/max_productive_steps) 631 | - **Disruption Count:** 632 | [workload/disruptions](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/disruptions) 633 | - **Step Time Deviation:** 634 | [workload/step_time_deviation](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/step_time_deviation) 635 | - **Ideal Step Time:** 636 | [workload/performance](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/performance) 637 | 638 | This feature leverages Google VM metadata (project ID, location, accelerator type) 639 | and supports replica IDs for uniquely identifying workloads in multi-replica 640 | deployments. 
641 | 642 | ```python 643 | 644 | gcp_options = goodput_utils.GCPOptions( 645 | project_id=None, # If None, the library will automatically identify from GCE internal metadata 646 | location=None, # If None, the library will automatically identify from GCE internal metadata 647 | replica_id='0', # Default is '0' 648 | acc_type=None, # If None, the library will automatically identify from GCE internal metadata 649 | enable_gcp_goodput_metrics=True, 650 | enable_gcp_step_deviation_metrics=True, 651 | ) 652 | 653 | goodput_monitor = monitoring.GoodputMonitor( 654 | job_name=config.run_name, 655 | logger_name=logger_name, 656 | tensorboard_dir=config.tensorboard_dir, 657 | upload_interval=config.goodput_upload_interval_seconds, 658 | monitoring_enabled=True, 659 | include_badput_breakdown=True, 660 | include_step_deviation=True, 661 | configured_ideal_step_time=None, # Optional, the library will compute ideal step time if it is not provided 662 | gcp_options=gcp_options 663 | ) 664 | ``` 665 | 666 | If you do not wish to send metrics to Google Cloud Monitoring then please set 667 | the flag `enable_gcp_goodput_metrics` to `False` for disabling goodput metrics 668 | and `enable_gcp_step_deviation_metrics` to `False` for disabling step deviation 669 | metrics while creating the GCPOptions object. 670 | 671 | Setting `monitoring_enabled` to `False` will disable both tensorboard and GCM 672 | monitoring. 
673 | 674 | ```python 675 | 676 | gcp_options = goodput_utils.GCPOptions( 677 | project_id=None, # If None, the library will automatically identify from GCE internal metadata 678 | location=None, # If None, the library will automatically identify from GCE internal metadata 679 | replica_id='0', # Default is '0' 680 | acc_type=None, # If None, the library will automatically identify from GCE internal metadata 681 | enable_gcp_goodput_metrics=False, 682 | enable_gcp_step_deviation_metrics=False, 683 | ) 684 | 685 | 686 | goodput_monitor = monitoring.GoodputMonitor( 687 | job_name=config.run_name, 688 | logger_name=logger_name, 689 | tensorboard_dir=config.tensorboard_dir, 690 | upload_interval=config.goodput_upload_interval_seconds, 691 | monitoring_enabled=True, 692 | include_badput_breakdown=True, 693 | include_step_deviation=True, 694 | configured_ideal_step_time=None, 695 | gcp_options=gcp_options, 696 | ) 697 | ``` 698 | 699 | If you want to monitor Goodput and Badput metrics computed in a specific window 700 | of time, you can use the `start_goodput_interval_uploader` monitoring API. 
701 | 702 | #### Create the `GoodputMonitor` with `enable_gcp_goodput_metrics` set to `True` in `GCPOptions` 703 | 704 | ```python 705 | 706 | gcp_options = goodput_utils.GCPOptions( 707 | project_id=None, # If None, the library will automatically identify from GCE internal metadata 708 | location=None, # If None, the library will automatically identify from GCE internal metadata 709 | replica_id='0', # Default is '0' 710 | acc_type=None, # If None, the library will automatically identify from GCE internal metadata 711 | enable_gcp_goodput_metrics=True, 712 | ) 713 | 714 | goodput_monitor = monitoring.GoodputMonitor( 715 | job_name=config.run_name, 716 | logger_name=logger_name, 717 | tensorboard_dir=config.tensorboard_dir, 718 | upload_interval=config.goodput_upload_interval_seconds, 719 | monitoring_enabled=True, 720 | include_badput_breakdown=True, 721 | gcp_options=gcp_options 722 | ) 723 | ``` -------------------------------------------------------------------------------- /ml_goodput_measurement/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from cloud_goodput.ml_goodput_measurement.src import checkpoint_badput_calculator 16 | from cloud_goodput.ml_goodput_measurement.src import gcp_metrics 17 | from cloud_goodput.ml_goodput_measurement.src import goodput 18 | from cloud_goodput.ml_goodput_measurement.src import goodput_cache 19 | from cloud_goodput.ml_goodput_measurement.src import goodput_utils 20 | from cloud_goodput.ml_goodput_measurement.src import monitoring 21 | -------------------------------------------------------------------------------- /ml_goodput_measurement/src/checkpoint_badput_calculator.py: -------------------------------------------------------------------------------- 1 | """Checkpoint Badput Calculator class.""" 2 | 3 | import argparse 4 | import dataclasses 5 | import statistics 6 | from typing import Dict, List, Optional 7 | 8 | import google.cloud.logging as google_cloud_logging 9 | 10 | 11 | _JOB_NAME = 'checkpoint_job' 12 | _LOGGER_NAME = 'checkpoint_logger' 13 | 14 | _STEP = 'step' 15 | _EVENT_TYPE = 'event_type' 16 | _DIRECTORY = 'directory' 17 | 18 | _WAIT_FOR_PREV_DURATION_SECS = 'wait_for_prev_duration_secs' 19 | 20 | _CHECKPOINTER_SAVE_DURATION_SECS = 'checkpointer_blocking_duration_secs' 21 | _CHECKPOINTER_RESTORE_DURATION_SECS = 'checkpointer_duration_secs' 22 | 23 | _GET_OLD_STEPS_DURATION_SECS = 'get_old_steps_duration_secs' 24 | 25 | _CHECKPOINT_MANAGER_SAVE_DURATION_SECS = 'checkpoint_manager_blocking_duration_secs' 26 | _CHECKPOINT_MANAGER_RESTORE_DURATION_SECS = 'checkpoint_manager_duration_secs' 27 | 28 | _BROADCAST_DURATION_SECS = 'broadcast_duration_secs' 29 | 30 | OPERATION_TYPE_SAVE = 'save' 31 | OPERATION_TYPE_RESTORE = 'restore' 32 | OPERATION_TYPE_EMERGENCY_RESTORE = 'emergency_restore' 33 | 34 | OPERATION_TYPE_LOCAL = 'local' 35 | OPERATION_TYPE_PERSISTENT = 'persistent' 36 | OPERATION_TYPE_PERSISTENT_AND_LOCAL = 'persistent_and_local' 37 | 38 | _CLOUD_LOGGING_PAGE_SIZE = 10000 39 | 40 | 41 | @dataclasses.dataclass 42 | class 
SaveCheckpointManagerVerticalStepStats: 43 | """Vertical step statistics for save operation.""" 44 | total_checkpoint_manager_blocking_time: float = 0.0 45 | average_checkpoint_manager_blocking_time: float = 0.0 46 | minimum_checkpoint_manager_blocking_time: float = 0.0 47 | maximum_checkpoint_manager_blocking_time: float = 0.0 48 | standard_deviation_checkpoint_manager_blocking_time: float = 0.0 49 | 50 | total_checkpointer_blocking_time: float = 0.0 51 | average_checkpointer_blocking_time: float = 0.0 52 | minimum_checkpointer_blocking_time: float = 0.0 53 | maximum_checkpointer_blocking_time: float = 0.0 54 | standard_deviation_checkpointer_blocking_time: float = 0.0 55 | 56 | total_wait_for_prev_time: float = 0.0 57 | average_wait_for_prev_time: float = 0.0 58 | minimum_wait_for_prev_time: float = 0.0 59 | maximum_wait_for_prev_time: float = 0.0 60 | standard_deviation_wait_for_prev_time: float = 0.0 61 | 62 | total_get_old_steps_time: float = 0.0 63 | average_get_old_steps_time: float = 0.0 64 | minimum_get_old_steps_time: float = 0.0 65 | maximum_get_old_steps_time: float = 0.0 66 | standard_deviation_get_old_steps_time: float = 0.0 67 | 68 | 69 | @dataclasses.dataclass 70 | class RestoreCheckpointManagerVerticalStepStats: 71 | """Vertical step statistics for restore operation.""" 72 | total_checkpoint_manager_time: float = 0.0 73 | average_checkpoint_manager_time: float = 0.0 74 | minimum_checkpoint_manager_time: float = 0.0 75 | maximum_checkpoint_manager_time: float = 0.0 76 | standard_deviation_checkpoint_manager_time: float = 0.0 77 | 78 | total_restore_time: float = 0.0 79 | average_restore_time: float = 0.0 80 | minimum_restore_time: float = 0.0 81 | maximum_restore_time: float = 0.0 82 | standard_deviation_restore_time: float = 0.0 83 | 84 | total_broadcast_time: float = 0.0 85 | average_broadcast_time: float = 0.0 86 | minimum_broadcast_time: float = 0.0 87 | maximum_broadcast_time: float = 0.0 88 | standard_deviation_broadcast_time: float = 0.0 89 | 
90 | 91 | @dataclasses.dataclass 92 | class SaveProcessedStep: 93 | """Horizontal save step stats for a processed step.""" 94 | step: str = '' 95 | total_checkpoint_manager_blocking_time: float = 0.0 96 | total_checkpointer_blocking_time: float = 0.0 97 | total_wait_for_prev_time: float = 0.0 98 | total_get_old_steps_time: float = 0.0 99 | occurrence: int = 0 100 | 101 | 102 | @dataclasses.dataclass 103 | class RestoreProcessedStep: 104 | """Horizontal restore step stats for a processed step.""" 105 | step: str = '' 106 | total_checkpoint_manager_time: float = 0.0 107 | total_restore_time: float = 0.0 108 | total_broadcast_time: float = 0.0 109 | broadcast_occurrence: int = 0 110 | occurrence: int = 0 111 | 112 | 113 | @dataclasses.dataclass 114 | class CheckpointLoggerOptions: 115 | """Checkpoint logger options.""" 116 | job_name: str = _JOB_NAME 117 | logger_name: str = _LOGGER_NAME 118 | client: Optional[google_cloud_logging.Client] = None 119 | use_goodput_logger: bool = False 120 | 121 | 122 | class CheckpointBadputCalculator: 123 | """Checkpoint Badput Calculator class.""" 124 | 125 | def __init__( 126 | self, options: CheckpointLoggerOptions = CheckpointLoggerOptions() 127 | ): 128 | self._options = options 129 | if not options.use_goodput_logger: 130 | if options.client is None: 131 | self.logging_client = google_cloud_logging.Client() 132 | else: 133 | self.logging_client = options.client 134 | self._logger = self.logging_client.logger(options.logger_name) 135 | self._use_goodput_logger = options.use_goodput_logger 136 | self.entries = [] 137 | 138 | def read_entries(self) -> List[Dict[str, str]]: 139 | """Queries Cloud Logging entries for the specific job. 140 | 141 | Returns: 142 | Filtered entries in ascending order of timestamp. 
143 | """ 144 | if self._use_goodput_logger: 145 | return self.entries 146 | 147 | filter_entries = [ 148 | 'severity=INFO', 149 | f'jsonPayload.job_name="{self._options.job_name}"', 150 | ] 151 | 152 | event_type_filter = ( 153 | '(jsonPayload.event_type=save OR jsonPayload.event_type=restore OR' 154 | ' jsonPayload.event_type=emergency_restore)' 155 | ) 156 | filter_entries.append(event_type_filter) 157 | 158 | filter_entries = ' AND '.join(filter_entries) 159 | 160 | entries = self._logger.list_entries( 161 | filter_=filter_entries, 162 | order_by=google_cloud_logging.ASCENDING, 163 | page_size=_CLOUD_LOGGING_PAGE_SIZE, 164 | ) 165 | entry_payload = [entry.payload for entry in entries] 166 | return entry_payload 167 | 168 | def _is_local_operation(self, step_stats: Dict[str, str]): 169 | if (step_stats[_DIRECTORY]).startswith('gs://'): 170 | return False 171 | else: 172 | return True 173 | 174 | def is_valid_save_stats( 175 | self, 176 | step_stats: Dict[str, str], 177 | operation_type: Optional[str] = OPERATION_TYPE_PERSISTENT_AND_LOCAL, 178 | ): 179 | """Checks if the step stats is valid. 180 | 181 | Args: 182 | step_stats: The step stats to check. 183 | operation_type: whether to check for local or persistent or both. 184 | 185 | Returns: 186 | Boolean indicating whether the step stats is valid. 187 | """ 188 | if ( 189 | _EVENT_TYPE not in step_stats 190 | or step_stats[_EVENT_TYPE] != OPERATION_TYPE_SAVE 191 | ): 192 | return False 193 | elif operation_type == OPERATION_TYPE_LOCAL: 194 | return self._is_local_operation(step_stats) 195 | elif operation_type == OPERATION_TYPE_PERSISTENT: 196 | return not self._is_local_operation(step_stats) 197 | else: 198 | return True 199 | 200 | def is_valid_restore_stats( 201 | self, 202 | step_stats: Dict[str, str], 203 | operation_type: Optional[str] = OPERATION_TYPE_PERSISTENT_AND_LOCAL, 204 | ): 205 | """Checks if the step stats is valid. 206 | 207 | Args: 208 | step_stats: The step stats to check. 
209 | operation_type: whether to check for local or persistent or both. 210 | 211 | Returns: 212 | Boolean indicating whether the step stats is valid. 213 | 214 | """ 215 | if _EVENT_TYPE not in step_stats: 216 | return False 217 | elif step_stats[_EVENT_TYPE] not in [ 218 | OPERATION_TYPE_RESTORE, 219 | OPERATION_TYPE_EMERGENCY_RESTORE, 220 | ]: 221 | return False 222 | elif operation_type == OPERATION_TYPE_LOCAL: 223 | return step_stats[_EVENT_TYPE] == OPERATION_TYPE_EMERGENCY_RESTORE 224 | elif operation_type == OPERATION_TYPE_PERSISTENT: 225 | return step_stats[_EVENT_TYPE] == OPERATION_TYPE_RESTORE 226 | else: 227 | return True 228 | 229 | def _save_statistics( 230 | self, processed_step_stats: Dict[str, SaveProcessedStep] 231 | ) -> SaveCheckpointManagerVerticalStepStats: 232 | """Gets the processed step stats.""" 233 | if not processed_step_stats: 234 | return SaveCheckpointManagerVerticalStepStats() 235 | 236 | for _, stats in processed_step_stats.items(): 237 | if stats.occurrence > 0: 238 | stats.total_checkpoint_manager_blocking_time = ( 239 | stats.total_checkpoint_manager_blocking_time / stats.occurrence 240 | ) 241 | stats.total_checkpointer_blocking_time = ( 242 | stats.total_checkpointer_blocking_time / stats.occurrence 243 | ) 244 | stats.total_wait_for_prev_time = ( 245 | stats.total_wait_for_prev_time / stats.occurrence 246 | ) 247 | stats.total_get_old_steps_time = ( 248 | stats.total_get_old_steps_time / stats.occurrence 249 | ) 250 | 251 | vertical_step_stats = SaveCheckpointManagerVerticalStepStats() 252 | 253 | # Record statistics for checkpoint_manager_blocking_time. 
254 | vertical_step_stats.total_checkpoint_manager_blocking_time = sum( 255 | map( 256 | lambda stats: stats.total_checkpoint_manager_blocking_time, 257 | processed_step_stats.values(), 258 | ) 259 | ) 260 | vertical_step_stats.average_checkpoint_manager_blocking_time = ( 261 | vertical_step_stats.total_checkpoint_manager_blocking_time 262 | / len(processed_step_stats) 263 | ) 264 | vertical_step_stats.minimum_checkpoint_manager_blocking_time = min( 265 | map( 266 | lambda stats: stats.total_checkpoint_manager_blocking_time, 267 | processed_step_stats.values(), 268 | ) 269 | ) 270 | vertical_step_stats.maximum_checkpoint_manager_blocking_time = max( 271 | map( 272 | lambda stats: stats.total_checkpoint_manager_blocking_time, 273 | processed_step_stats.values(), 274 | ) 275 | ) 276 | if len(processed_step_stats) > 1: 277 | vertical_step_stats.standard_deviation_checkpoint_manager_blocking_time = ( 278 | statistics.stdev( 279 | map( 280 | lambda stats: stats.total_checkpoint_manager_blocking_time, 281 | processed_step_stats.values(), 282 | ) 283 | ) 284 | ) 285 | 286 | # Record statistics for checkpointer_blocking_time. 
287 | vertical_step_stats.total_checkpointer_blocking_time = sum( 288 | map( 289 | lambda stats: stats.total_checkpointer_blocking_time, 290 | processed_step_stats.values(), 291 | ) 292 | ) 293 | vertical_step_stats.average_checkpointer_blocking_time = ( 294 | vertical_step_stats.total_checkpointer_blocking_time 295 | / len(processed_step_stats) 296 | ) 297 | vertical_step_stats.minimum_checkpointer_blocking_time = min( 298 | map( 299 | lambda stats: stats.total_checkpointer_blocking_time, 300 | processed_step_stats.values(), 301 | ) 302 | ) 303 | vertical_step_stats.maximum_checkpointer_blocking_time = max( 304 | map( 305 | lambda stats: stats.total_checkpointer_blocking_time, 306 | processed_step_stats.values(), 307 | ) 308 | ) 309 | if len(processed_step_stats) > 1: 310 | vertical_step_stats.standard_deviation_checkpointer_blocking_time = ( 311 | statistics.stdev( 312 | map( 313 | lambda stats: stats.total_checkpointer_blocking_time, 314 | processed_step_stats.values(), 315 | ) 316 | ) 317 | ) 318 | 319 | # Record statistics for wait_for_prev_time. 
320 | vertical_step_stats.total_wait_for_prev_time = sum( 321 | map( 322 | lambda stats: stats.total_wait_for_prev_time, 323 | processed_step_stats.values(), 324 | ) 325 | ) 326 | vertical_step_stats.average_wait_for_prev_time = ( 327 | vertical_step_stats.total_wait_for_prev_time 328 | / len(processed_step_stats) 329 | ) 330 | vertical_step_stats.minimum_wait_for_prev_time = min( 331 | map( 332 | lambda stats: stats.total_wait_for_prev_time, 333 | processed_step_stats.values(), 334 | ) 335 | ) 336 | vertical_step_stats.maximum_wait_for_prev_time = max( 337 | map( 338 | lambda stats: stats.total_wait_for_prev_time, 339 | processed_step_stats.values(), 340 | ) 341 | ) 342 | if len(processed_step_stats) > 1: 343 | vertical_step_stats.standard_deviation_wait_for_prev_time = ( 344 | statistics.stdev( 345 | map( 346 | lambda stats: stats.total_wait_for_prev_time, 347 | processed_step_stats.values(), 348 | ) 349 | ) 350 | ) 351 | 352 | # Record statistics for get_old_steps_time. 353 | vertical_step_stats.total_get_old_steps_time = sum( 354 | map( 355 | lambda stats: stats.total_get_old_steps_time, 356 | processed_step_stats.values(), 357 | ) 358 | ) 359 | vertical_step_stats.average_get_old_steps_time = ( 360 | vertical_step_stats.total_get_old_steps_time / len(processed_step_stats) 361 | ) 362 | vertical_step_stats.minimum_get_old_steps_time = min( 363 | map( 364 | lambda stats: stats.total_get_old_steps_time, 365 | processed_step_stats.values(), 366 | ) 367 | ) 368 | vertical_step_stats.maximum_get_old_steps_time = max( 369 | map( 370 | lambda stats: stats.total_get_old_steps_time, 371 | processed_step_stats.values(), 372 | ) 373 | ) 374 | if len(processed_step_stats) > 1: 375 | vertical_step_stats.standard_deviation_get_old_steps_time = ( 376 | statistics.stdev( 377 | map( 378 | lambda stats: stats.total_get_old_steps_time, 379 | processed_step_stats.values(), 380 | ) 381 | ) 382 | ) 383 | return vertical_step_stats 384 | 385 | def 
calculate_save_operation_checkpoint_manager_blocking_time( 386 | self, operation_type: Optional[str] = OPERATION_TYPE_PERSISTENT_AND_LOCAL, 387 | ) -> SaveCheckpointManagerVerticalStepStats: 388 | """Gets checkpoint manager blocking time breakdown for save operation.""" 389 | self.entries = self.read_entries() 390 | 391 | step_already_processed: dict[str, SaveProcessedStep] = dict() 392 | for step_stats in self.entries: 393 | if ( 394 | not self.is_valid_save_stats(step_stats, operation_type) 395 | ): 396 | continue 397 | 398 | # Create a step info to identify the step_statistics whether local or 399 | # persistent. 400 | if self._is_local_operation(step_stats): 401 | step_info = str(step_stats[_STEP]) + '-' + OPERATION_TYPE_LOCAL 402 | else: 403 | step_info = ( 404 | str(step_stats[_STEP]) + '-' + OPERATION_TYPE_PERSISTENT 405 | ) 406 | if step_already_processed.get(step_info) is None: 407 | step_already_processed[step_info] = SaveProcessedStep() 408 | step_already_processed[step_info].step = step_info 409 | step_already_processed[ 410 | step_info 411 | ].total_checkpoint_manager_blocking_time = float( 412 | step_stats[_CHECKPOINT_MANAGER_SAVE_DURATION_SECS] 413 | ) 414 | step_already_processed[step_info].total_checkpointer_blocking_time = ( 415 | float(step_stats[_CHECKPOINTER_SAVE_DURATION_SECS]) 416 | ) 417 | step_already_processed[step_info].total_wait_for_prev_time = float( 418 | step_stats[_WAIT_FOR_PREV_DURATION_SECS] 419 | ) 420 | step_already_processed[step_info].total_get_old_steps_time = float( 421 | step_stats[_GET_OLD_STEPS_DURATION_SECS] 422 | ) 423 | step_already_processed[step_info].occurrence = 1 424 | else: 425 | step_already_processed[step_info].step = step_info 426 | step_already_processed[ 427 | step_info 428 | ].total_checkpoint_manager_blocking_time += float( 429 | step_stats[_CHECKPOINT_MANAGER_SAVE_DURATION_SECS] 430 | ) 431 | step_already_processed[ 432 | step_info 433 | ].total_checkpointer_blocking_time += float( 434 | 
step_stats[_CHECKPOINTER_SAVE_DURATION_SECS] 435 | ) 436 | step_already_processed[step_info].total_wait_for_prev_time += float( 437 | step_stats[_WAIT_FOR_PREV_DURATION_SECS] 438 | ) 439 | step_already_processed[step_info].total_get_old_steps_time += float( 440 | step_stats[_GET_OLD_STEPS_DURATION_SECS] 441 | ) 442 | step_already_processed[step_info].occurrence += 1 443 | 444 | # Calculate the vertical step stats for the checkpoint manager blocking 445 | # time. 446 | save_statistics = self._save_statistics( 447 | step_already_processed 448 | ) 449 | 450 | return save_statistics 451 | 452 | def _restore_statistics( 453 | self, processed_step_stats: Dict[str, RestoreProcessedStep] 454 | ) -> RestoreCheckpointManagerVerticalStepStats: 455 | """Calculates the vertical step stats.""" 456 | if not processed_step_stats: 457 | return RestoreCheckpointManagerVerticalStepStats() 458 | broadcast_occurrence = 0 459 | for _, stats in processed_step_stats.items(): 460 | stats.total_checkpoint_manager_time = ( 461 | stats.total_checkpoint_manager_time / stats.occurrence 462 | ) 463 | stats.total_restore_time = stats.total_restore_time / stats.occurrence 464 | if stats.broadcast_occurrence > 0: 465 | stats.total_broadcast_time = ( 466 | stats.total_broadcast_time / stats.broadcast_occurrence 467 | ) 468 | broadcast_occurrence += 1 469 | 470 | vertical_step_stats = RestoreCheckpointManagerVerticalStepStats() 471 | 472 | # Record statistics for total time checkpoint manager spent on restore. 
473 | vertical_step_stats.total_checkpoint_manager_time = sum( 474 | map( 475 | lambda stats: stats.total_checkpoint_manager_time, 476 | processed_step_stats.values(), 477 | ) 478 | ) 479 | vertical_step_stats.average_checkpoint_manager_time = ( 480 | vertical_step_stats.total_checkpoint_manager_time 481 | / len(processed_step_stats) 482 | ) 483 | vertical_step_stats.minimum_checkpoint_manager_time = min( 484 | map( 485 | lambda stats: stats.total_checkpoint_manager_time, 486 | processed_step_stats.values(), 487 | ) 488 | ) 489 | vertical_step_stats.maximum_checkpoint_manager_time = max( 490 | map( 491 | lambda stats: stats.total_checkpoint_manager_time, 492 | processed_step_stats.values(), 493 | ) 494 | ) 495 | if len(processed_step_stats) > 1: 496 | vertical_step_stats.standard_deviation_checkpoint_manager_time = ( 497 | statistics.stdev( 498 | map( 499 | lambda stats: stats.total_checkpoint_manager_time, 500 | processed_step_stats.values(), 501 | ) 502 | ) 503 | ) 504 | # Record statistics for restore time. 505 | vertical_step_stats.total_restore_time = sum( 506 | map( 507 | lambda stats: stats.total_restore_time, 508 | processed_step_stats.values(), 509 | ) 510 | ) 511 | vertical_step_stats.average_restore_time = ( 512 | vertical_step_stats.total_restore_time / len(processed_step_stats) 513 | ) 514 | vertical_step_stats.minimum_restore_time = min( 515 | map( 516 | lambda stats: stats.total_restore_time, 517 | processed_step_stats.values(), 518 | ) 519 | ) 520 | vertical_step_stats.maximum_restore_time = max( 521 | map( 522 | lambda stats: stats.total_restore_time, 523 | processed_step_stats.values(), 524 | ) 525 | ) 526 | if len(processed_step_stats) > 1: 527 | vertical_step_stats.standard_deviation_restore_time = ( 528 | statistics.stdev( 529 | map( 530 | lambda stats: stats.total_restore_time, 531 | processed_step_stats.values(), 532 | ) 533 | ) 534 | ) 535 | 536 | # Record statistics for broadcasting the restored checkpoint(Emergency 537 | # restore only). 
538 | if broadcast_occurrence > 0: 539 | vertical_step_stats.total_broadcast_time = sum( 540 | map( 541 | lambda stats: stats.total_broadcast_time, 542 | processed_step_stats.values(), 543 | ) 544 | ) 545 | vertical_step_stats.average_broadcast_time = ( 546 | vertical_step_stats.total_broadcast_time / broadcast_occurrence 547 | ) 548 | vertical_step_stats.minimum_broadcast_time = min( 549 | map( 550 | lambda stats: stats.total_broadcast_time, 551 | processed_step_stats.values(), 552 | ) 553 | ) 554 | vertical_step_stats.maximum_broadcast_time = max( 555 | map( 556 | lambda stats: stats.total_broadcast_time, 557 | processed_step_stats.values(), 558 | ) 559 | ) 560 | if len(processed_step_stats) > 1: 561 | vertical_step_stats.standard_deviation_broadcast_time = ( 562 | statistics.stdev( 563 | map( 564 | lambda stats: stats.total_broadcast_time, 565 | processed_step_stats.values(), 566 | ) 567 | ) 568 | ) 569 | 570 | return vertical_step_stats 571 | 572 | def calculate_restore_operation_checkpoint_manager_blocking_time( 573 | self, 574 | operation_type: Optional[str] = OPERATION_TYPE_PERSISTENT_AND_LOCAL, 575 | ) -> RestoreCheckpointManagerVerticalStepStats: 576 | """Gets checkpoint manager blocking time breakdown for restore operation.""" 577 | self.entries = self.read_entries() 578 | 579 | step_already_processed: dict[str, RestoreProcessedStep] = dict() 580 | for step_stats in self.entries: 581 | if not self.is_valid_restore_stats(step_stats, operation_type): 582 | continue 583 | 584 | # Create a step info to identify the step_stats whether local or 585 | if self._is_local_operation(step_stats): 586 | step_info = str(step_stats[_STEP]) + '-' + OPERATION_TYPE_LOCAL 587 | else: 588 | step_info = str(step_stats[_STEP]) + '-' + OPERATION_TYPE_PERSISTENT 589 | 590 | if step_already_processed.get(step_info) is None: 591 | step_already_processed[step_info] = RestoreProcessedStep() 592 | step_already_processed[step_info].step = step_info 593 | 594 | 
step_already_processed[step_info].total_checkpoint_manager_time = float( 595 | step_stats[_CHECKPOINT_MANAGER_RESTORE_DURATION_SECS] 596 | ) 597 | step_already_processed[step_info].total_restore_time = float( 598 | step_stats[_CHECKPOINTER_RESTORE_DURATION_SECS] 599 | ) 600 | if ( 601 | step_stats.get(_BROADCAST_DURATION_SECS) 602 | and step_stats[_BROADCAST_DURATION_SECS] is not None 603 | ): 604 | step_already_processed[step_info].total_broadcast_time = float( 605 | step_stats[_BROADCAST_DURATION_SECS] 606 | ) 607 | step_already_processed[step_info].broadcast_occurrence = 1 608 | step_already_processed[step_info].occurrence = 1 609 | else: 610 | step_already_processed[step_info].step = step_info 611 | step_already_processed[ 612 | step_info 613 | ].total_checkpoint_manager_time += float( 614 | step_stats[_CHECKPOINT_MANAGER_RESTORE_DURATION_SECS] 615 | ) 616 | step_already_processed[step_info].total_restore_time += float( 617 | step_stats[_CHECKPOINTER_RESTORE_DURATION_SECS] 618 | ) 619 | if ( 620 | step_stats.get(_BROADCAST_DURATION_SECS) 621 | and step_stats[_BROADCAST_DURATION_SECS] is not None 622 | ): 623 | step_already_processed[step_info].total_broadcast_time += float( 624 | step_stats[_BROADCAST_DURATION_SECS] 625 | ) 626 | step_already_processed[step_info].broadcast_occurrence += 1 627 | step_already_processed[step_info].occurrence += 1 628 | 629 | # Calculate the vertical step stats for the checkpoint manager blocking 630 | # time. 
631 | restore_statistics = self._restore_statistics(step_already_processed) 632 | 633 | return restore_statistics 634 | 635 | if __name__ == '__main__': 636 | parser = argparse.ArgumentParser() 637 | options = CheckpointLoggerOptions() 638 | parser.add_argument( 639 | '--job_name', 640 | type=str, 641 | default=options.job_name, 642 | help='The name of the job.', 643 | ) 644 | parser.add_argument( 645 | '--logger_name', 646 | type=str, 647 | default=options.logger_name, 648 | help='The name of the logger.', 649 | ) 650 | parser.add_argument( 651 | '--client', 652 | type=str, 653 | default=options.client, 654 | help='The name of the client.', 655 | ) 656 | parser.add_argument( 657 | '--operation_type', 658 | type=str, 659 | default=OPERATION_TYPE_PERSISTENT_AND_LOCAL, 660 | help='The operation type.', 661 | ) 662 | args = parser.parse_args() 663 | options = CheckpointLoggerOptions( 664 | job_name=args.job_name, 665 | logger_name=args.logger_name, 666 | client=args.client, 667 | ) 668 | checkpoint_badput_calculator = ( 669 | CheckpointBadputCalculator(options) 670 | ) 671 | checkpoint_badput_calculator.calculate_save_operation_checkpoint_manager_blocking_time( 672 | args.operation_type 673 | ) 674 | 675 | 676 | 677 | -------------------------------------------------------------------------------- /ml_goodput_measurement/src/gcp_metrics.py: -------------------------------------------------------------------------------- 1 | """A generic class to send multiple metrics to GCP Cloud Monitoring in a batch with dynamic resources.""" 2 | import enum 3 | import logging 4 | import time 5 | from typing import Any, Dict 6 | 7 | from google.api_core import exceptions 8 | from google.cloud import monitoring_v3 9 | 10 | GoogleAPIError = exceptions.GoogleAPIError 11 | Enum = enum.Enum 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class ValueType(Enum): 17 | """Enum for metric value types.""" 18 | 19 | BOOL = "bool_value" 20 | INT = "int64_value" 21 | DOUBLE = 
"double_value" 22 | STRING = "string_value" 23 | DISTRIBUTION = "distribution_value" # Add other types as needed 24 | 25 | 26 | class GCPMetrics: 27 | """A generic class to send multiple metrics to GCP Cloud Monitoring in a batch with dynamic resources.""" 28 | 29 | def __init__(self, project_id: str): 30 | """Initializes the GCPMetrics.""" 31 | self.project_id = project_id 32 | self.client = monitoring_v3.MetricServiceClient() 33 | self.project_name = f"projects/{project_id}" 34 | 35 | def create_time_series( 36 | self, 37 | metric_type: str, 38 | value, 39 | value_type: ValueType, 40 | metric_labels: Dict[str, str], 41 | resource_type: str, 42 | resource_labels: Dict[str, str], 43 | seconds: int, 44 | nanos: int, 45 | ) -> monitoring_v3.TimeSeries: 46 | """Creates a TimeSeries object for a single metric with dynamic resources.""" 47 | series = monitoring_v3.TimeSeries() 48 | series.metric.type = metric_type 49 | series.resource.type = resource_type 50 | series.resource.labels.update(resource_labels) 51 | if metric_labels: 52 | series.metric.labels.update(metric_labels) 53 | 54 | point = monitoring_v3.Point( 55 | interval=monitoring_v3.TimeInterval( 56 | end_time={"seconds": seconds, "nanos": nanos} 57 | ), 58 | value=monitoring_v3.TypedValue(**{value_type.value: value}), 59 | ) 60 | series.points.append(point) 61 | 62 | return series 63 | 64 | def send_metrics(self, metrics: list[Dict[str, Any]]): 65 | """Sends multiple metrics to GCP Monitoring in a batch with dynamic resources. 66 | 67 | Args: 68 | metrics: A list of dictionaries, where each dictionary represents 69 | a metric. Each dictionary should have the following keys: 70 | - 'metric_type': str 71 | - 'value': The metric value. 
72 | - 'value_type': ValueType (e.g., ValueType.INT, 73 | ValueType.DOUBLE) 74 | - 'metric_labels': dict (optional) 75 | - 'resource_type': str 76 | - 'resource_labels': dict 77 | """ 78 | try: 79 | now = time.time() 80 | seconds = int(now) 81 | nanos = int((now - seconds) * 10**9) 82 | 83 | time_series_list = [] 84 | for metric in metrics: 85 | try: 86 | metric_labels = metric.get("metric_labels", {}) 87 | series = self.create_time_series( 88 | metric["metric_type"], 89 | metric["value"], 90 | metric["value_type"], 91 | metric_labels, 92 | metric["resource_type"], 93 | metric["resource_labels"], 94 | seconds, 95 | nanos, 96 | ) 97 | time_series_list.append(series) 98 | except Exception as e: # pylint: disable=broad-exception-caught 99 | logger.error("Failed to create time series: %s", e) 100 | self.client.create_time_series( 101 | name=self.project_name, time_series=time_series_list 102 | ) 103 | logger.info("Sent %d metrics to GCP Monitoring.", len(metrics)) 104 | 105 | except GoogleAPIError as e: 106 | logger.error("Failed to send metrics: %s", e) 107 | -------------------------------------------------------------------------------- /ml_goodput_measurement/src/goodput_cache.py: -------------------------------------------------------------------------------- 1 | """Goodput Cache implementations.""" 2 | 3 | import datetime 4 | from typing import Any 5 | 6 | from cloud_goodput.ml_goodput_measurement.src import goodput_utils 7 | 8 | 9 | StepInfo = goodput_utils.StepInfo 10 | GoodputInfo = goodput_utils.GoodputInfo 11 | _JOB_START_TIME = 'job_start_time' 12 | _JOB_END_TIME = 'job_end_time' 13 | _STEP_START_TIME = 'step_start_time' 14 | 15 | 16 | class GoodputCache: 17 | """Goodput Cache.""" 18 | 19 | def __init__(self): 20 | self._cached_entries = [] 21 | self._step_entries = [] 22 | self._goodput_info = None 23 | self._job_start_time = None 24 | self._job_end_time = None 25 | self._step_info = None 26 | self._last_entry_time = None 27 | 28 | def 
update_step_info(self, step_info: StepInfo): 29 | """Updates the step information.""" 30 | self._step_info = step_info 31 | 32 | def update_cached_entries(self, entries: list[Any]): 33 | """Updated the cached entries.""" 34 | self._cached_entries.extend(entries) 35 | self.update_last_entry_time() 36 | self.update_job_start_time() 37 | self.update_job_end_time() 38 | new_step_entries = [entry for entry in entries if _STEP_START_TIME in entry] 39 | self._step_entries.extend(new_step_entries) 40 | 41 | def update_last_entry_time(self): 42 | """Helper function to store the timestamp of the last entry in the cache.""" 43 | if self._cached_entries: 44 | last_entry = self._cached_entries[-1] 45 | entry_time = goodput_utils.get_entry_time_from_log_entry(last_entry) 46 | if entry_time: 47 | self._last_entry_time = entry_time 48 | 49 | def update_job_start_time(self): 50 | """Updates the job start time.""" 51 | # If the job start time is not set, try to find it in the cached entries. 52 | if self._job_start_time is None and self._cached_entries: 53 | for entry in self._cached_entries: 54 | if _JOB_START_TIME in entry: 55 | self._job_start_time = datetime.datetime.fromtimestamp( 56 | entry[_JOB_START_TIME], tz=datetime.timezone.utc 57 | ) 58 | break 59 | 60 | def update_job_end_time(self): 61 | """Updates the job end time.""" 62 | # Overwrite the latest job end time if cached entries contain the job end 63 | # time. 
64 | if self._job_end_time is None and self._cached_entries: 65 | for entry in reversed(self._cached_entries): 66 | if _JOB_END_TIME in entry: 67 | self._job_end_time = datetime.datetime.fromtimestamp( 68 | entry[_JOB_END_TIME], tz=datetime.timezone.utc 69 | ) 70 | break 71 | 72 | def update_goodput_info(self, goodput_info: GoodputInfo): 73 | """Updates the last computed Goodput information.""" 74 | self._goodput_info = goodput_info 75 | 76 | def get_cached_entries(self): 77 | """Returns the cached entries.""" 78 | return self._cached_entries 79 | 80 | def get_step_entries(self): 81 | """Returns the step entries.""" 82 | return self._step_entries 83 | 84 | def get_goodput_info(self): 85 | """Returns the last computed Goodput information.""" 86 | return self._goodput_info 87 | 88 | def get_job_start_time(self): 89 | """Returns the job start time.""" 90 | return self._job_start_time 91 | 92 | def get_job_end_time(self): 93 | """Returns the job end time.""" 94 | return self._job_end_time 95 | 96 | def get_last_entry_time(self): 97 | """Returns the last entry time.""" 98 | return self._last_entry_time 99 | 100 | def get_step_info(self): 101 | """Returns the step information.""" 102 | return self._step_info 103 | 104 | def clear_cache(self): 105 | """Clears the cache.""" 106 | self._cached_entries = [] 107 | self._goodput_info = None 108 | self._last_entry_time = None 109 | 110 | def is_cache_empty(self) -> bool: 111 | """Checks if the cache is empty.""" 112 | return not self._cached_entries 113 | -------------------------------------------------------------------------------- /ml_goodput_measurement/src/goodput_utils.py: -------------------------------------------------------------------------------- 1 | """Goodput Utility Classes and Helpers.""" 2 | 3 | import dataclasses 4 | import datetime 5 | import enum 6 | import logging 7 | from typing import Any, Optional, TypedDict 8 | 9 | import numpy as np 10 | import requests 11 | from scipy import stats 12 | from 
urllib3.util import retry 13 | 14 | 15 | Retry = retry.Retry 16 | _TIME_ENTRY = 'time' 17 | _METADATA_SERVER_URL = 'http://metadata.google.internal/computeMetadata/v1/' 18 | _METADATA_HEADERS = {'Metadata-Flavor': 'Google'} 19 | 20 | MACHINE_TYPE_TO_ACCELERATOR_TYPE_MAPPING = { 21 | 'ct6e': 'TPU-v6e', 22 | 'ct5p': 'TPU-v5p', 23 | 'ct5lp': 'TPU-v5e', 24 | 'ct5l': 'TPU-v5e', 25 | 'ct4p': 'TPU-v4p', 26 | 'ct3p': 'TPU-v3', 27 | 'ct3': 'TPU-v3', 28 | 'tpu-v2': 'TPU-v2', 29 | 'tpu': 'TPU', 30 | 'a3-edgegpu': 'NVIDIA-H100', 31 | 'a3-highgpu': 'NVIDIA-H100', 32 | 'a3-megagpu': 'NVIDIA-H100', 33 | 'a3-ultragpu': 'NVIDIA-H200', 34 | 'a2': 'NVIDIA-A100', 35 | 'gpu': 'GPU', 36 | } 37 | 38 | 39 | @dataclasses.dataclass 40 | class GCPOptions: 41 | project_id: Optional[str] = None 42 | location: Optional[str] = None 43 | replica_id: str = '0' 44 | acc_type: Optional[str] = None 45 | enable_gcp_goodput_metrics: bool = True 46 | enable_gcp_step_deviation_metrics: bool = True 47 | 48 | 49 | @dataclasses.dataclass 50 | class EntryTime: 51 | field_name: str 52 | timestamp: float 53 | 54 | 55 | # Cumulative metric types for upload and monitoring. 56 | class MetricType(enum.Enum): 57 | """The type of CUMULATIVE Metric.""" 58 | GOODPUT_TIME = 'goodput_time' 59 | BADPUT_TIME = 'badput_time' 60 | MAX_PRODUCTIVE_STEP = 'max_productive_step' 61 | TOTAL_ELAPSED_TIME = 'total_elapsed_time' 62 | DISRUPTION_COUNT = 'disruption_count' 63 | STEP_TIME_DEVIATION = 'step_time_deviation' 64 | IDEAL_STEP_TIME = 'ideal_step_time' 65 | 66 | 67 | # Interval metric types for upload and monitoring. 68 | class IntervalMetricType(enum.Enum): 69 | """The type of INTERVAL Metric.""" 70 | 71 | INTERVAL_GOODPUT = 'interval_goodput' 72 | INTERVAL_BADPUT = 'interval_badput' 73 | INTERVAL_SIZE = 'interval_size' 74 | 75 | 76 | # Productive time is not broken down by activities yet. As such, we only have 77 | # one type of Goodput which contributes to the total productive time. 
78 | class GoodputType(enum.Enum): 79 | """The type of Goodput.""" 80 | 81 | TOTAL = 1 82 | 83 | 84 | class BadputType(enum.Enum): 85 | """The type of Badput.""" 86 | 87 | TPU_INITIALIZATION = 1 88 | TRAINING_PREP = 2 89 | PROGRAM_STARTUP = 3 90 | DATA_LOADING_SYNC = 4 91 | DATA_LOADING_ASYNC = 5 92 | UNPRODUCTIVE_CHECKPOINT_SAVE_TIME = 6 93 | UNPRODUCTIVE_CHECKPOINT_RESTORE_TIME = 7 94 | WASTED_PROGRESS_FROM_DISRUPTION = 8 95 | CUSTOM_BADPUT_EVENTS = 9 96 | OTHER = 10 97 | 98 | 99 | class WorkloadMetricDetails(TypedDict): 100 | goodput_time: dict[GoodputType, float] 101 | badput_time: dict[BadputType, float | dict[str, float]] 102 | max_productive_step: int 103 | total_elapsed_time: float 104 | disruption_count: int 105 | step_time_deviation: dict[int, float] 106 | ideal_step_time: float 107 | 108 | 109 | class IntervalWorkloadMetricDetails(TypedDict): 110 | interval_goodput: dict[GoodputType, float] 111 | interval_badput: dict[BadputType, float | dict[str, float]] 112 | interval_size: int # Unit: seconds. 113 | 114 | 115 | ACTIVITY_EXCLUSION_LIST = [ 116 | # DATA_LOADING_ASYNC is not a non-productive activity as it is not 117 | # blocking. Hence, we exclude it from calculating Goodput. 
118 | 'DATA_LOADING_ASYNC', 119 | ] 120 | 121 | 122 | class MonitoringWindowType(enum.Enum): 123 | """The type of Monitoring Window.""" 124 | 125 | CUMULATIVE = 'cumulative' 126 | INTERVAL = 'interval' 127 | 128 | 129 | _DEFAULT_RECENT_WINDOW_SIZE = 100 130 | _DEFAULT_BASELINE_WINDOW_SIZE = 1000 131 | _DEFAULT_SPIKE_PERCENTILE = 90 132 | 133 | 134 | class GoodputInfo: 135 | """Goodput Information.""" 136 | 137 | def __init__( 138 | self, 139 | total_productive_time: float = 0.0, 140 | total_elapsed_time: float = 0.0, 141 | total_unproductive_time: Optional[dict[BadputType, float]] = None, 142 | max_productive_step: int = 0, 143 | last_recorded_step: int = 0, 144 | last_updated_timestamp: datetime.datetime = datetime.datetime.now( 145 | datetime.timezone.utc 146 | ), 147 | number_of_disruptions: int = 0, 148 | ): 149 | self.total_productive_time = total_productive_time 150 | self.total_elapsed_time = total_elapsed_time 151 | 152 | # We cannot use {} as the default argument directly because it's a mutable 153 | # default argument. Mutable default arguments are shared between all 154 | # instances of the class. If one instance modifies the default 155 | # dictionary, it will affect all other instances. Instead, we use 156 | # None as a sentinel value and create a new dictionary inside the 157 | # __init__ method if no dictionary is provided. This ensures each 158 | # instance gets its own dictionary. 
159 | self.total_unproductive_time = ( 160 | total_unproductive_time or {} 161 | ) 162 | self.max_productive_step = max_productive_step 163 | self.last_recorded_step = last_recorded_step 164 | self.last_updated_timestamp = last_updated_timestamp 165 | self.number_of_disruptions = number_of_disruptions 166 | 167 | 168 | class StepInfo: 169 | """Step Information.""" 170 | 171 | def __init__( 172 | self, 173 | ideal_step_time: float, 174 | step_deviations: dict[int, float], 175 | ): 176 | self.ideal_step_time = ideal_step_time 177 | self.step_deviations = step_deviations 178 | 179 | 180 | def compute_percentile(values: list[float], percentile: float) -> float: 181 | """Computes the specified percentile value from a list of floats.""" 182 | if not values: 183 | return 0.0 184 | 185 | sorted_values = sorted(values) 186 | index = (len(sorted_values) - 1) * (percentile / 100.0) 187 | lower_index = int(index) 188 | upper_index = min(lower_index + 1, len(sorted_values) - 1) 189 | 190 | return sorted_values[lower_index] + ( 191 | sorted_values[upper_index] - sorted_values[lower_index] 192 | ) * (index - lower_index) 193 | 194 | 195 | def compute_step_deviation_from_baseline( 196 | step_time_deviation: dict[int, float], 197 | mode: MonitoringWindowType = MonitoringWindowType.CUMULATIVE, 198 | recent_window_size: int = _DEFAULT_RECENT_WINDOW_SIZE, 199 | baseline_window_size: int = _DEFAULT_BASELINE_WINDOW_SIZE, 200 | spike_percentile: int = _DEFAULT_SPIKE_PERCENTILE, 201 | ) -> float: 202 | """Computes a spike-sensitive step time deviation metric. 203 | 204 | Args: 205 | step_time_deviation: Ordered dict (step count -> step deviation in seconds). 206 | mode: 'cumulative' to compare against a historical baseline; 'interval' to 207 | reflect short-term spikes only. 208 | recent_window_size: Number of recent steps to consider for interval mode. 209 | baseline_window_size: Number of older steps for cumulative baseline. 
210 | spike_percentile: Percentile to use for recent deviation sensitivity. 211 | 212 | Returns: 213 | The step deviation from the baseline. 214 | """ 215 | if not step_time_deviation: 216 | return 0.0 217 | 218 | deviations = [abs(deviation) for deviation in step_time_deviation.values()] 219 | total_steps = len(deviations) 220 | 221 | if total_steps < _DEFAULT_RECENT_WINDOW_SIZE: 222 | return np.mean(deviations) 223 | 224 | if mode == MonitoringWindowType.INTERVAL: 225 | recent_deviations = deviations[-recent_window_size:] 226 | return compute_percentile(recent_deviations, spike_percentile) 227 | 228 | elif mode == MonitoringWindowType.CUMULATIVE: 229 | if total_steps < (recent_window_size + baseline_window_size): 230 | recent_deviations = deviations[-recent_window_size:] 231 | return compute_percentile(recent_deviations, spike_percentile) 232 | 233 | recent_deviations = deviations[-recent_window_size:] 234 | baseline_deviations = deviations[ 235 | -(recent_window_size + baseline_window_size) : -recent_window_size 236 | ] 237 | 238 | if not baseline_deviations: 239 | return compute_percentile(recent_deviations, spike_percentile) 240 | 241 | baseline_median = np.median(baseline_deviations) 242 | spike_value = compute_percentile(recent_deviations, spike_percentile) 243 | return spike_value - baseline_median 244 | 245 | else: 246 | raise ValueError('Unsupported MonitoringWindowType mode: {mode}') 247 | 248 | 249 | def compute_ideal_step_time(step_times: list[float]) -> Optional[float]: 250 | """Helper function to compute the ideal step time.""" 251 | # Filter out step times that may be less than 1 second. 252 | step_times = [step_time for step_time in step_times if step_time >= 1.0] 253 | if not step_times: 254 | return None 255 | # Compute the median absolute deviation (MAD) and median of the step times 256 | mad = stats.median_abs_deviation(step_times) 257 | med = np.median(step_times) 258 | 259 | # Normalize the step times to the median + 3 * MAD. 
260 | normal_step_times = [ 261 | step_time for step_time in step_times if step_time <= (med + mad * 3) 262 | ] 263 | return np.mean(normal_step_times) if normal_step_times else None 264 | 265 | 266 | def get_anomalous_and_normal_step_times( 267 | step_times: list[Any], 268 | ) -> tuple[list[Any], list[Any]]: 269 | """Helper function to get anomalous and normal step times.""" 270 | mad = stats.median_abs_deviation(step_times) 271 | med = np.median(step_times) 272 | 273 | anomalous_step_times = [] 274 | normal_step_times = [] 275 | for step_time in step_times: 276 | if step_time > (med + mad * 3): 277 | anomalous_step_times.append(step_time) 278 | else: 279 | normal_step_times.append(step_time) 280 | 281 | return anomalous_step_times, normal_step_times 282 | 283 | 284 | def get_extra_time_from_anomalous_steps(step_times: list[Any]) -> float: 285 | anomalous_step_times, normal_step_times = get_anomalous_and_normal_step_times( 286 | step_times 287 | ) 288 | normal_step_mean = np.mean(normal_step_times) 289 | return sum(anomalous_step_times) - ( 290 | len(anomalous_step_times) * normal_step_mean 291 | ) 292 | 293 | 294 | def get_entry_time_from_log_entry( 295 | entry: dict[str, Any], 296 | ) -> Optional[EntryTime]: 297 | """Extracts the TimeEntry from a log entry.""" 298 | for entry_label, entry_value in entry.items(): 299 | if _TIME_ENTRY in entry_label and isinstance(entry_value, (int, float)): 300 | return EntryTime(field_name=entry_label, timestamp=float(entry_value)) 301 | return None 302 | 303 | 304 | def get_timestamp_from_log_entry( 305 | entry: dict[str, Any], 306 | ) -> Optional[datetime.datetime]: 307 | """Helper function to get the timestamp from a log entry.""" 308 | timestamp_posix_time = [ 309 | entry_value 310 | for entry_label, entry_value in entry.items() 311 | if _TIME_ENTRY in entry_label 312 | ] 313 | if timestamp_posix_time: 314 | return datetime.datetime.fromtimestamp( 315 | timestamp_posix_time[0], datetime.timezone.utc 316 | ) 317 | return None 
318 | 319 | 320 | def get_gcp_metadata(category: str, attribute: str, timeout=5, retries=3): 321 | """Fetch the specified attribute from GCP metadata server. 322 | 323 | Args: 324 | category (str): The high-level metadata category (ex: 'instance', 325 | 'project'). 326 | attribute (str): The attribute to fetch under this category (ex: 'id', 327 | 'zone'). 328 | timeout (int): Timeout for the request in seconds. 329 | retries (int): Number of retry attempts for transient failures. 330 | 331 | Returns: 332 | str: The metadata value as a string, or None if the request fails. 333 | """ 334 | target_url = f'{_METADATA_SERVER_URL}{category}/{attribute}' 335 | 336 | session = requests.Session() 337 | retry_strategy = Retry( 338 | total=retries, 339 | backoff_factor=0.5, 340 | # Retry on the following status codes 341 | status_forcelist=[429, 500, 502, 503, 504], 342 | ) 343 | adapter = requests.adapters.HTTPAdapter(max_retries=retry_strategy) 344 | session.mount('http://', adapter) 345 | 346 | try: 347 | response = session.get( 348 | target_url, headers=_METADATA_HEADERS, timeout=timeout 349 | ) 350 | response.raise_for_status() 351 | return response.text 352 | except requests.exceptions.RequestException as e: 353 | logging.warning( 354 | 'Failed to retrieve metadata for %s/%s: %s', category, attribute, e 355 | ) 356 | return None 357 | 358 | 359 | def get_gcp_project_id(): 360 | """Returns the project id of the current GCP project.""" 361 | return get_gcp_metadata('project', 'project-id') 362 | 363 | 364 | def get_node_zone(): 365 | """Returns the zone of the GCE instance.""" 366 | zone_path = get_gcp_metadata('instance', 'zone') 367 | # example zone_path: "projects/123456789/zones/us-central1-a" 368 | return zone_path.rsplit('/', 1)[-1] if zone_path else None 369 | 370 | 371 | def get_accelerator_type(): 372 | """Retrieves the accelerator type from GCP metadata. 373 | 374 | For GKE TPU VMs, it extracts the type from the 'machine-type' metadata. 
375 | 376 | Returns: 377 | str: The accelerator type, or 'UNKNOWN' if not found. 378 | """ 379 | machine_type_url = get_gcp_metadata('instance', 'machine-type') 380 | # example machine_type_url: "projects/123456789/machineTypes/a3-highgpu-8g" 381 | machine_type_name = ( 382 | machine_type_url.split('/')[-1] if machine_type_url else None 383 | ) 384 | 385 | if not machine_type_name: 386 | return 'UNKNOWN' 387 | 388 | for ( 389 | prefix, 390 | accelerator_type, 391 | ) in MACHINE_TYPE_TO_ACCELERATOR_TYPE_MAPPING.items(): 392 | if prefix.lower() in machine_type_name.lower(): 393 | return accelerator_type 394 | 395 | return 'UNKNOWN' 396 | -------------------------------------------------------------------------------- /ml_goodput_measurement/src/monitoring.py: -------------------------------------------------------------------------------- 1 | """Goodput monitoring API. 2 | 3 | This file contains all the utilities to monitor and upload goodput data of a 4 | user workload to Tensorboard asynchronously. 
5 | """ 6 | 7 | import datetime 8 | import logging 9 | import math 10 | import os 11 | import threading 12 | import time 13 | 14 | from cloud_goodput.ml_goodput_measurement.src import gcp_metrics 15 | from cloud_goodput.ml_goodput_measurement.src import goodput 16 | from cloud_goodput.ml_goodput_measurement.src import goodput_utils 17 | from tensorboardX import writer 18 | 19 | BadputType = goodput_utils.BadputType 20 | GCPOptions = goodput_utils.GCPOptions 21 | GCPMetrics = gcp_metrics.GCPMetrics 22 | GoodputCalculator = goodput.GoodputCalculator 23 | IntervalMetricType = goodput_utils.IntervalMetricType 24 | IntervalWorkloadMetricDetails = goodput_utils.IntervalWorkloadMetricDetails 25 | MetricType = goodput_utils.MetricType 26 | MonitoringWindowType = goodput_utils.MonitoringWindowType 27 | ValueType = gcp_metrics.ValueType 28 | UnproductiveTimeDict = goodput.UnproductiveTimeDict 29 | WorkloadMetricDetails = goodput_utils.WorkloadMetricDetails 30 | 31 | ACTIVITY_EXCLUSION_LIST = goodput_utils.ACTIVITY_EXCLUSION_LIST 32 | _TENSORBOARD_GCS_SUBDIR = 'goodput' 33 | _TENSORBOARD_GOODPUT_LABEL = 'goodput' 34 | _TENSORBOARD_BADPUT_LABEL = 'badput' 35 | _TENSORBOARD_STEP_DEVIATION_LABEL = 'step_deviation' 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | class GoodputMonitor: 41 | """Queries and uploads goodput data to Tensorboard at a regular interval.""" 42 | 43 | def __init__( 44 | self, 45 | job_name: str, 46 | logger_name: str, 47 | tensorboard_dir: str, 48 | upload_interval: int, 49 | monitoring_enabled: bool = False, 50 | pathway_enabled: bool = False, 51 | include_badput_breakdown=False, 52 | include_step_deviation=False, 53 | configured_ideal_step_time=None, 54 | step_deviation_interval_seconds=10, 55 | gcp_options: GCPOptions = GCPOptions(), 56 | ): 57 | """Initializes the GoodputMonitor. 58 | 59 | Args: 60 | job_name: The name of the job to monitor. 61 | logger_name: The name of the Google Cloud Logging logger to use. 
62 | tensorboard_dir: The directory to write TensorBoard data to. 63 | upload_interval: The interval to upload data to TensorBoard and GCP 64 | Monitoring. 65 | monitoring_enabled: Whether to enable monitoring. If the application is 66 | interested in monitoring Goodput, it should set this value to True if 67 | monitoring from TPU worker 0 andthe application's configurations 68 | request Goodput monitoring. 69 | pathway_enabled: Whether the application is using Pathways. 70 | include_badput_breakdown: Whether to query and upload badput breakdown 71 | data to Tensorboard. 72 | include_step_deviation: Whether to query and upload step deviation data 73 | to Tensorboard. 74 | configured_ideal_step_time: The optional ideal step time configured by 75 | the user. 76 | step_deviation_interval_seconds: The interval to query step deviation 77 | data. 78 | gcp_options: The options for Google Cloud Monitoring. 79 | """ 80 | if not monitoring_enabled: 81 | logger.info( 82 | 'Monitoring is disabled. Returning without initializing' 83 | ' GoodputMonitor.' 84 | ) 85 | return 86 | 87 | # Common configurations. 88 | self._job_name = job_name 89 | self._logger_name = logger_name 90 | self._tensorboard_dir = os.path.join( 91 | tensorboard_dir, _TENSORBOARD_GCS_SUBDIR 92 | ) 93 | # Goodput configurations. 94 | self._upload_interval = upload_interval 95 | self._include_badput_breakdown = include_badput_breakdown 96 | 97 | # Step deviation configurations. 98 | self._include_step_deviation = include_step_deviation 99 | self._step_deviation_interval_seconds = step_deviation_interval_seconds 100 | self._configured_ideal_step_time = configured_ideal_step_time 101 | 102 | # Initialize the GoodputCalculator. 
103 | self._goodput_calculator = GoodputCalculator( 104 | job_name=self._job_name, 105 | logger_name=self._logger_name, 106 | using_pathways=pathway_enabled, 107 | ) 108 | self._writer = writer.SummaryWriter(self._tensorboard_dir) 109 | 110 | # Goodput uploader flags to signal the daemon thread if it exists when to 111 | # initate shutdown and wait for termination. 112 | self._goodput_uploader_thread_running = False 113 | self._goodput_upload_thread = None 114 | self._termination_event = threading.Event() 115 | self._termination_event.clear() 116 | 117 | # Step deviation threading flags. 118 | self._step_deviation_uploader_thread_running = False 119 | self._step_deviation_upload_thread = None 120 | self._step_deviation_termination_event = threading.Event() 121 | self._step_deviation_termination_event.clear() 122 | 123 | # Google Cloud Monitoring configurations. 124 | self._gcp_options = gcp_options 125 | self._metrics_sender = None 126 | 127 | # If step deviation is not included, disable GCP step deviation metrics. 128 | if not self._include_step_deviation: 129 | self._gcp_options.enable_gcp_step_deviation_metrics = False 130 | 131 | if ( 132 | self._gcp_options.enable_gcp_goodput_metrics 133 | or self._gcp_options.enable_gcp_step_deviation_metrics 134 | ): 135 | if not self._gcp_options.project_id: 136 | self._gcp_options.project_id = goodput_utils.get_gcp_project_id() 137 | if not self._gcp_options.location: 138 | self._gcp_options.location = goodput_utils.get_node_zone() 139 | if not self._gcp_options.acc_type: 140 | self._gcp_options.acc_type = goodput_utils.get_accelerator_type() 141 | if self._gcp_options.project_id and self._gcp_options.location: 142 | self._metrics_sender = GCPMetrics( 143 | project_id=self._gcp_options.project_id 144 | ) 145 | else: 146 | self._gcp_options.enable_gcp_goodput_metrics = False 147 | self._gcp_options.enable_gcp_step_deviation_metrics = False 148 | logger.warning( 149 | 'Project ID or location is not set. 
Google Cloud Monitoring will not be' 150 | ' enabled.' 151 | ) 152 | # Goodput interval uploader flags. 153 | self._interval_uploader_thread_running = False 154 | self._interval_goodput_upload_thread = None 155 | self._interval_termination_event = threading.Event() 156 | self._interval_termination_event.clear() 157 | self._rolling_windows = [] 158 | 159 | def __del__(self): 160 | try: 161 | self.stop_goodput_uploader() 162 | self.stop_step_deviation_uploader() 163 | self.stop_goodput_rolling_window_uploader() 164 | 165 | except Exception: # pylint: disable=broad-exception-caught 166 | pass 167 | 168 | def _log_tensorboard_scalars( 169 | self, 170 | label_prefix: str, 171 | data: dict[str, float | dict[str, float]], 172 | step: int, 173 | ): 174 | """Logs scalar values (flat or nested) to TensorBoard under a label prefix.""" 175 | if self._writer is None: 176 | return 177 | 178 | for data_type, data_value in data.items(): 179 | if isinstance(data_value, dict): 180 | for subtype, subval in data_value.items(): 181 | full_label = f'{label_prefix}/{data_type}/{subtype}'.lower() 182 | self._writer.add_scalar( 183 | full_label, float(subval), step, display_name=subtype.lower() 184 | ) 185 | else: 186 | full_label = f'{label_prefix}/{data_type.lower()}' 187 | self._writer.add_scalar( 188 | full_label, float(data_value), step, display_name=data_type.lower() 189 | ) 190 | 191 | self._writer.flush() 192 | 193 | def _upload_goodput_metrics_to_tensorboard( 194 | self, 195 | job_goodput: float, 196 | badput_breakdown: UnproductiveTimeDict, 197 | last_step: int, 198 | ): 199 | """Writes goodput and badput breakdown to Tensorboard.""" 200 | try: 201 | self._write_goodput_to_tensorboard(job_goodput, last_step) 202 | if self._include_badput_breakdown: 203 | self._write_badput_to_tensorboard(badput_breakdown, last_step) 204 | except Exception as e: # pylint: disable=broad-exception-caught 205 | logger.error( 206 | 'Error while writing goodput and badput data to Tensorboard. 
This' 207 | ' will not impact the workload. Error: %s', 208 | e, 209 | ) 210 | 211 | def _write_goodput_to_tensorboard(self, job_goodput: float, last_step: int): 212 | self._log_tensorboard_scalars( 213 | _TENSORBOARD_GOODPUT_LABEL, 214 | {_TENSORBOARD_GOODPUT_LABEL: job_goodput}, 215 | last_step, 216 | ) 217 | 218 | def _write_badput_to_tensorboard( 219 | self, 220 | job_badput_breakdown: UnproductiveTimeDict, 221 | last_step: int, 222 | ): 223 | """Writes badput breakdown to TensorBoard.""" 224 | flattened_badput: dict[str, float | dict[str, float]] = {} 225 | 226 | for badput_type, badput_value in job_badput_breakdown.items(): 227 | if isinstance(badput_value, dict): 228 | flattened_badput[badput_type.name.lower()] = { 229 | subtype.lower(): value for subtype, value in badput_value.items() 230 | } 231 | else: 232 | flattened_badput[badput_type.name.lower()] = badput_value 233 | 234 | self._log_tensorboard_scalars( 235 | _TENSORBOARD_BADPUT_LABEL, 236 | flattened_badput, 237 | last_step, 238 | ) 239 | 240 | def _flatten_badput_dict( 241 | self, 242 | badput_time_dict: UnproductiveTimeDict, 243 | ) -> list[tuple[str, float]]: 244 | """Flattens nested badput types into (label, value) pairs for export.""" 245 | flat_badput = [] 246 | for badput_type, val in badput_time_dict.items(): 247 | if isinstance(val, dict): 248 | for subtype, subval in val.items(): 249 | flat_badput.append((f'{badput_type.name}.{subtype.upper()}', subval)) 250 | else: 251 | flat_badput.append((badput_type.name, val)) 252 | return flat_badput 253 | 254 | def _upload_goodput_metrics_to_gcm( 255 | self, goodput_details: WorkloadMetricDetails 256 | ): 257 | """Sends goodput and badput metrics to GCM.""" 258 | try: 259 | gcm_metrics = [] 260 | 261 | # Populate goodput time metrics. 
def _upload_goodput_metrics_to_gcm(
    self, goodput_details: 'WorkloadMetricDetails'
):
  """Sends goodput and badput metrics to GCM.

  Builds one GCM time-series entry per goodput source, per flattened
  badput source, plus disruption count, max productive step, step time
  deviation (when available), total elapsed time and ideal step time,
  then sends them in a single batch.
  """
  try:
    shared_resource_labels = {
        'location': self._gcp_options.location,
        'workload_id': self._job_name,
        'replica_id': self._gcp_options.replica_id,
    }

    def workload_entry(metric_type, value, value_type, metric_labels=None):
      # Builds one GCM metric dict; key order mirrors the hand-written
      # entries (metric_labels, when present, precedes resource_type).
      entry = {
          'metric_type': metric_type,
          'value': value,
          'value_type': value_type,
      }
      if metric_labels is not None:
        entry['metric_labels'] = metric_labels
      entry['resource_type'] = 'compute.googleapis.com/Workload'
      entry['resource_labels'] = dict(shared_resource_labels)
      return entry

    gcm_metrics = []

    # Productive (goodput) time, one entry per source.
    for goodput_type, time_value in goodput_details[
        MetricType.GOODPUT_TIME.value
    ].items():
      if goodput_type.name in ACTIVITY_EXCLUSION_LIST:
        continue
      gcm_metrics.append(
          workload_entry(
              'compute.googleapis.com/workload/goodput_time',
              time_value,
              ValueType.DOUBLE,
              {
                  'goodput_source': goodput_type.name,
                  'accelerator_type': self._gcp_options.acc_type,
              },
          )
      )

    # Unproductive (badput) time, one entry per flattened source.
    for badput_label, time_value in self._flatten_badput_dict(
        goodput_details[MetricType.BADPUT_TIME.value]
    ):
      if badput_label in ACTIVITY_EXCLUSION_LIST:
        continue
      gcm_metrics.append(
          workload_entry(
              'compute.googleapis.com/workload/badput_time',
              time_value,
              ValueType.DOUBLE,
              {
                  'badput_source': badput_label,
                  'accelerator_type': self._gcp_options.acc_type,
              },
          )
      )

    # Cumulative disruption count.
    gcm_metrics.append(
        workload_entry(
            'compute.googleapis.com/workload/disruptions',
            goodput_details[MetricType.DISRUPTION_COUNT.value],
            ValueType.INT,
            {
                'accelerator_type': self._gcp_options.acc_type,
                'window_type': MonitoringWindowType.CUMULATIVE.value,
            },
        )
    )

    # Highest productive step reached.
    gcm_metrics.append(
        workload_entry(
            'compute.googleapis.com/workload/max_productive_steps',
            goodput_details[MetricType.MAX_PRODUCTIVE_STEP.value],
            ValueType.INT,
            {'accelerator_type': self._gcp_options.acc_type},
        )
    )

    # Step time deviation, reduced to a single deviation-from-baseline
    # value; skipped entirely when no deviations were recorded.
    step_time_deviations = goodput_details[
        MetricType.STEP_TIME_DEVIATION.value
    ]
    if step_time_deviations:
      deviation_from_baseline = (
          goodput_utils.compute_step_deviation_from_baseline(
              step_time_deviations
          )
      )
      gcm_metrics.append(
          workload_entry(
              'compute.googleapis.com/workload/step_time_deviation',
              deviation_from_baseline,
              ValueType.DOUBLE,
              {'accelerator_type': self._gcp_options.acc_type},
          )
      )

    # Total wall-clock time elapsed for the job.
    gcm_metrics.append(
        workload_entry(
            'compute.googleapis.com/workload/total_elapsed_time',
            goodput_details[MetricType.TOTAL_ELAPSED_TIME.value],
            ValueType.DOUBLE,
            {
                'accelerator_type': self._gcp_options.acc_type,
                'window_type': MonitoringWindowType.CUMULATIVE.value,
            },
        )
    )

    # Ideal step time (performance); this entry carries no metric labels.
    gcm_metrics.append(
        workload_entry(
            'compute.googleapis.com/workload/performance',
            goodput_details[MetricType.IDEAL_STEP_TIME.value],
            ValueType.DOUBLE,
        )
    )

    # Send metrics to Google Cloud Monitoring.
    if self._metrics_sender and gcm_metrics:
      self._metrics_sender.send_metrics(gcm_metrics)

  except Exception as e:  # pylint: disable=broad-exception-caught
    logger.error(
        'Error while sending goodput metrics to GCM. This'
        ' will not impact the workload. Error: %s',
        e,
    )

def _query_and_upload_goodput(self):
  """Background loop: periodically queries goodput and uploads it."""
  while not self._termination_event.is_set():
    time.sleep(self._upload_interval)
    # Query metrics and update the cache; a failed query skips the cycle
    # rather than killing the uploader thread.
    try:
      job_goodput, job_badput_breakdown, last_step = (
          self._goodput_calculator.get_job_goodput(
              include_badput_breakdown=self._include_badput_breakdown
          )
      )
    except Exception as e:  # pylint: disable=broad-exception-caught
      logger.error(
          'Error while querying goodput. Skipping this cycle. Error: %s', e
      )
      continue
    # Upload metrics to Tensorboard.
    self._upload_goodput_metrics_to_tensorboard(
        job_goodput, job_badput_breakdown, last_step
    )
    # Upload metrics to Google Cloud Monitoring.
    if self._gcp_options.enable_gcp_goodput_metrics:
      self._upload_goodput_metrics_to_gcm(
          self._goodput_calculator.get_job_goodput_details()
      )
def _final_goodput_query_and_upload(self):
  """Runs one last cumulative goodput query and upload to Tensorboard & GCM."""
  logger.info(
      'Final goodput query and upload for job: %s and logger: %s',
      self._job_name,
      self._logger_name,
  )
  try:
    job_goodput, job_badput_breakdown, last_step = (
        self._goodput_calculator.get_job_goodput(
            include_badput_breakdown=self._include_badput_breakdown
        )
    )
    self._upload_goodput_metrics_to_tensorboard(
        job_goodput, job_badput_breakdown, last_step
    )
    if self._gcp_options.enable_gcp_goodput_metrics:
      self._upload_goodput_metrics_to_gcm(
          self._goodput_calculator.get_job_goodput_details()
      )
    logger.info(
        'Final goodput query and upload for job: %s and logger: %s completed'
        ' with total goodput: %.2f%%, last step: %d',
        self._job_name,
        self._logger_name,
        job_goodput,
        last_step,
    )
  except Exception as e:  # pylint: disable=broad-exception-caught
    logger.error(
        'Error while performing final goodput query and upload for job: %s'
        ' and logger: %s. This will not impact the workload. Error: %s',
        self._job_name,
        self._logger_name,
        e,
    )

def start_goodput_uploader(self):
  """Launches the background cumulative goodput uploader thread.

  Raises:
    RuntimeError: If the uploader thread is already running.
  """
  if self._goodput_uploader_thread_running:
    raise RuntimeError('Goodput uploader thread is already running.')

  self._termination_event.clear()
  self._goodput_upload_thread = threading.Thread(
      target=self._query_and_upload_goodput, daemon=True
  )
  logger.info(
      'Starting goodput query and uploader thread in the background for job:'
      ' %s and logger: %s',
      self._job_name,
      self._logger_name,
  )
  self._goodput_upload_thread.start()
  self._goodput_uploader_thread_running = True

def stop_goodput_uploader(self):
  """Stops the cumulative goodput uploader thread and performs a final cumulative goodput upload.

  Raises:
    RuntimeError: If the uploader thread is not running.
  """
  if not self._goodput_uploader_thread_running:
    raise RuntimeError('Cumulative goodput uploader thread is not running.')

  self._termination_event.set()
  if self._goodput_upload_thread is not None:
    logger.info(
        'Waiting for cumulative goodput query and uploader thread to'
        ' complete.'
    )
    self._goodput_upload_thread.join()
    self._goodput_upload_thread = None
  logger.info(
      'Cumulative goodput query and uploader thread stopped. No more goodput'
      ' data will be uploaded to Tensorboard or GCM.'
  )
  self._goodput_uploader_thread_running = False
  # One final synchronous upload after the thread has terminated.
  self._final_goodput_query_and_upload()
def _write_step_deviation_to_tensorboard(
    self, step_deviation: dict[int, float]
):
  """Logs each step's deviation scalar to TensorBoard.

  Args:
    step_deviation: Mapping of step number to step deviation value.
  """
  if self._writer is None:
    return
  # Use a dedicated loop variable; the previous code shadowed the
  # `step_deviation` parameter inside the loop.
  for step_count, deviation_value in step_deviation.items():
    self._writer.add_scalar(
        _TENSORBOARD_STEP_DEVIATION_LABEL,
        float(deviation_value),
        step_count,
    )
  self._writer.flush()

def _send_step_deviation_metric_to_gcp(self, step_deviations):
  """Sends the average step deviation metric to GCM.

  Args:
    step_deviations: Mapping of step number to step deviation; empty or
      NaN-averaging input is skipped with a warning.
  """
  try:
    if not step_deviations:
      logger.warning(
          'Step deviation is empty. This will not impact the workload.'
      )
      return
    avg_step_deviation = sum(step_deviations.values()) / len(step_deviations)

    if math.isnan(avg_step_deviation):
      logger.warning(
          'Step deviation is NaN. This will not impact the workload.'
      )
      return

    perf_metric = [{
        'metric_type': 'compute.googleapis.com/workload/performance',
        'value': avg_step_deviation,
        'value_type': ValueType.DOUBLE,
        'resource_type': 'compute.googleapis.com/Workload',
        'resource_labels': {
            'location': self._gcp_options.location,
            'workload_id': self._job_name,
            'replica_id': self._gcp_options.replica_id,
        },
    }]
    if self._metrics_sender:
      self._metrics_sender.send_metrics(perf_metric)
  except Exception as e:  # pylint: disable=broad-exception-caught
    logger.error(
        'Error while sending step deviation to GCM.'
        ' This will not impact the workload. Error: %s',
        e,
    )

def _query_and_upload_step_deviation_to_tensorboard_and_gcp(self):
  """Queries and uploads step deviation data to Tensorboard and GCM."""
  try:
    step_deviation = self._goodput_calculator.get_step_deviation(
        self._configured_ideal_step_time
    )
    self._write_step_deviation_to_tensorboard(step_deviation)
    if self._gcp_options.enable_gcp_step_deviation_metrics:
      self._send_step_deviation_metric_to_gcp(step_deviation)
  except Exception as e:  # pylint: disable=broad-exception-caught
    logger.error(
        'Error while querying and uploading step deviation to Tensorboard.'
        ' This will not impact the workload. Error: %s',
        e,
    )

def _query_and_upload_step_deviation(self):
  """Background loop: periodically uploads step deviation data."""
  while not self._step_deviation_termination_event.is_set():
    time.sleep(self._step_deviation_interval_seconds)
    self._query_and_upload_step_deviation_to_tensorboard_and_gcp()

def _final_step_deviation_query_and_upload(self):
  """Performs final step deviation query and uploads data to Tensorboard & GCM."""
  logger.info(
      'Final step deviation query and upload for job: %s and logger: %s',
      self._job_name,
      self._logger_name,
  )
  try:
    step_deviation = self._goodput_calculator.get_step_deviation(
        self._configured_ideal_step_time
    )
    self._write_step_deviation_to_tensorboard(step_deviation)
    if self._gcp_options.enable_gcp_step_deviation_metrics:
      self._send_step_deviation_metric_to_gcp(step_deviation)
    logger.info(
        'Final step deviation query and upload for job: %s and logger: %s'
        ' completed',
        self._job_name,
        self._logger_name,
    )
  except Exception as e:  # pylint: disable=broad-exception-caught
    logger.error(
        'Error while performing final step deviation query and upload for'
        ' job: %s and logger: %s. This will not impact the workload. Error:'
        ' %s',
        self._job_name,
        self._logger_name,
        e,
    )

def start_step_deviation_uploader(self):
  """Starts the step deviation uploader thread.

  No-op (with an info log) when step deviation monitoring is disabled.

  Raises:
    RuntimeError: If the step deviation uploader thread is already running.
  """
  if not self._include_step_deviation:
    logger.info(
        'Step deviation monitoring is disabled. Returning without'
        ' initializing step deviation uploader thread.'
    )
    return

  if self._step_deviation_uploader_thread_running:
    raise RuntimeError('Step deviation uploader thread is already running.')

  self._step_deviation_termination_event.clear()
  self._step_deviation_upload_thread = threading.Thread(
      target=self._query_and_upload_step_deviation, daemon=True
  )
  logger.info(
      'Starting step deviation query and uploader thread in the background'
      ' for job: %s and logger: %s',
      self._job_name,
      self._logger_name,
  )
  self._step_deviation_upload_thread.start()
  self._step_deviation_uploader_thread_running = True

def stop_step_deviation_uploader(self):
  """Stops the step deviation uploader thread and performs a final upload.

  Raises:
    RuntimeError: If the step deviation uploader thread is not running.
  """
  if not self._step_deviation_uploader_thread_running:
    raise RuntimeError('Step deviation uploader thread is not running.')

  self._step_deviation_termination_event.set()
  if self._step_deviation_upload_thread is not None:
    logger.info(
        'Waiting for step deviation query and uploader thread to complete.'
    )
    self._step_deviation_upload_thread.join()
    # Drop the joined thread reference, mirroring stop_goodput_uploader.
    self._step_deviation_upload_thread = None
  logger.info(
      'Step deviation query and uploader thread stopped. No more step'
      ' deviation data will be uploaded to Tensorboard or GCM.'
  )
  self._step_deviation_uploader_thread_running = False
  # Final step deviation query and upload.
  self._final_step_deviation_query_and_upload()
def _final_rolling_window_goodput_query_and_upload(self):
  """Performs final rolling window goodput query and uploads data to GCM for all rolling windows."""
  logger.info(
      'Final rolling window goodput query and upload for job: %s and'
      ' logger: %s',
      self._job_name,
      self._logger_name,
  )
  try:
    # Reuse the shared per-window upload path; previously this loop was
    # duplicated verbatim here and in the periodic uploader.
    self._upload_all_rolling_windows()
    logger.info(
        'Final rolling window goodput query and upload for job: %s and'
        ' logger: %s completed.',
        self._job_name,
        self._logger_name,
    )
  except Exception as e:  # pylint: disable=broad-exception-caught
    logger.error(
        'Error while performing final rolling window goodput query and upload'
        ' for job: %s and logger: %s. This will not impact the workload.'
        ' Error: %s',
        self._job_name,
        self._logger_name,
        e,
    )

def _upload_all_rolling_windows(self):
  """Queries and uploads metrics once for every configured rolling window."""
  # Anchor every window to the same "now" so the windows line up.
  now = datetime.datetime.now(datetime.timezone.utc)
  for window_size in self._rolling_windows:
    window_end = now
    # `now` is already UTC-aware, so no tzinfo adjustment is needed.
    window_start = now - datetime.timedelta(seconds=window_size)
    interval_metric_details = (
        self._goodput_calculator.get_interval_metric_details(
            window_start, window_end
        )
    )
    self._upload_interval_goodput_metrics_to_gcm(interval_metric_details)

def _upload_interval_goodput_metrics_to_gcm(
    self,
    interval_metric_details: 'IntervalWorkloadMetricDetails',
):
  """Uploads interval goodput metrics to GCM.

  Args:
    interval_metric_details: Interval metrics containing the window size,
      per-source interval goodput and interval badput.
  """
  try:
    gcm_metrics = []
    window_size = interval_metric_details[
        IntervalMetricType.INTERVAL_SIZE.value
    ]

    # Interval goodput, one entry per goodput source.
    for goodput_type, goodput_value in interval_metric_details[
        IntervalMetricType.INTERVAL_GOODPUT.value
    ].items():
      if goodput_type.name in ACTIVITY_EXCLUSION_LIST:
        continue
      gcm_metrics.append({
          'metric_type': 'compute.googleapis.com/workload/interval_goodput',
          'value': goodput_value,
          'value_type': ValueType.DOUBLE,
          'metric_labels': {
              'goodput_source': goodput_type.name,
              'accelerator_type': self._gcp_options.acc_type,
              'rolling_window_size': str(window_size),
          },
          'resource_type': 'compute.googleapis.com/Workload',
          'resource_labels': {
              'location': self._gcp_options.location,
              'workload_id': self._job_name,
              'replica_id': self._gcp_options.replica_id,
          },
      })

    # Interval badput, one entry per flattened badput source.
    for badput_type, badput_value in self._flatten_badput_dict(
        interval_metric_details[IntervalMetricType.INTERVAL_BADPUT.value]
    ):
      if badput_type in ACTIVITY_EXCLUSION_LIST:
        continue
      gcm_metrics.append({
          'metric_type': 'compute.googleapis.com/workload/interval_badput',
          'value': badput_value,
          'value_type': ValueType.DOUBLE,
          'metric_labels': {
              'badput_source': badput_type,
              'accelerator_type': self._gcp_options.acc_type,
              'rolling_window_size': str(window_size),
          },
          'resource_type': 'compute.googleapis.com/Workload',
          'resource_labels': {
              'location': self._gcp_options.location,
              'workload_id': self._job_name,
              'replica_id': self._gcp_options.replica_id,
          },
      })

    if self._metrics_sender:
      self._metrics_sender.send_metrics(gcm_metrics)

  except Exception as e:  # pylint: disable=broad-exception-caught
    logger.error(
        'Error while uploading interval goodput metrics to GCM. This will'
        ' not impact the workload. Error: %s',
        e,
    )

def _query_and_upload_rolling_window_goodput(self):
  """Background loop: periodically queries and uploads rolling window goodput to GCM."""
  while not self._interval_termination_event.is_set():
    time.sleep(self._upload_interval)
    if not self._gcp_options.enable_gcp_goodput_metrics:
      continue
    self._upload_all_rolling_windows()

def start_rolling_window_goodput_uploader(
    self, rolling_windows_seconds: list[int]
):
  """Starts the goodput uploader thread for user-specified interval windows.

  Args:
    rolling_windows_seconds: Sizes, in seconds, of the rolling windows to
      query and upload on every cycle.

  Raises:
    RuntimeError: If the rolling window uploader thread is already running.
  """
  if self._interval_uploader_thread_running:
    raise RuntimeError('Goodput interval uploader thread is already running.')

  self._interval_termination_event.clear()
  self._rolling_windows = rolling_windows_seconds
  self._interval_goodput_upload_thread = threading.Thread(
      target=self._query_and_upload_rolling_window_goodput,
      daemon=True,
  )
  logger.info(
      'Starting rolling window goodput query and uploader thread in the'
      ' background for job: %s and logger: %s',
      self._job_name,
      self._logger_name,
  )
  self._interval_goodput_upload_thread.start()
  self._interval_uploader_thread_running = True

def stop_goodput_rolling_window_uploader(self):
  """Stops the rolling window goodput uploader thread and performs a final rolling window goodput upload.

  Raises:
    RuntimeError: If the rolling window uploader thread is not running.
  """
  if not self._interval_uploader_thread_running:
    raise RuntimeError(
        'Rolling window goodput uploader thread is not running.'
    )

  self._interval_termination_event.set()
  if self._interval_goodput_upload_thread is not None:
    logger.info(
        'Waiting for rolling window goodput query and uploader thread to'
        ' complete.'
    )
    self._interval_goodput_upload_thread.join()
    self._interval_goodput_upload_thread = None
  logger.info(
      'Rolling window goodput query and uploader thread stopped. No more'
      ' rolling window goodput data will be uploaded to GCM.'
  )

  self._interval_uploader_thread_running = False

  # Perform the final rolling window goodput query and upload.
  self._final_rolling_window_goodput_query_and_upload()
not running.' 811 | ) 812 | 813 | self._interval_termination_event.set() 814 | if self._interval_goodput_upload_thread is not None: 815 | logger.info( 816 | 'Waiting for rolling window goodput query and uploader thread to' 817 | ' complete.' 818 | ) 819 | self._interval_goodput_upload_thread.join() 820 | self._interval_goodput_upload_thread = None 821 | logger.info( 822 | 'Rolling window goodput query and uploader thread stopped. No more' 823 | ' rolling window goodput data will be uploaded to GCM.' 824 | ) 825 | 826 | self._interval_uploader_thread_running = False 827 | 828 | # Perform the final rolling window goodput query and upload 829 | self._final_rolling_window_goodput_query_and_upload() 830 | -------------------------------------------------------------------------------- /ml_goodput_measurement/tests/checkpoint_badput_calculator_test.py: -------------------------------------------------------------------------------- 1 | """Tests for checkpoint badput calculator.""" 2 | 3 | import dataclasses 4 | from typing import Optional 5 | 6 | from absl.testing import absltest 7 | from cloud_goodput.ml_goodput_measurement.src import checkpoint_badput_calculator 8 | import google.cloud.logging as google_cloud_logging 9 | import mock 10 | 11 | 12 | _JOB_NAME = 'checkpoint_job' 13 | _LOGGER_NAME = 'checkpoint_logger' 14 | 15 | 16 | @dataclasses.dataclass 17 | class MockSaveStepStatistics: 18 | """Attributes for save step statistics. 19 | 20 | Attributes: 21 | step: The step number. 22 | event_type: The event type. 23 | checkpoint_manager_blocking_start_time: The start time of checkpoint manager 24 | blocking section. 25 | directory: The directory of the checkpoint. 26 | reached_preemption: Whether the event reached preemption. 27 | preemption_received_at: The time when preemption was received. 28 | wait_for_prev_start_time: The start time of waiting for previous checkpoint. 
@dataclasses.dataclass
class MockRestoreStepStatistics:
  """Container mirroring restore step statistics.

  Attributes:
    step: Step number of the restored checkpoint.
    event_type: Event type marker; defaults to 'restore' for this record.
    directory: Location of the checkpoint being restored.
    checkpointer_start_time: When the checkpointer-level restore began.
    checkpointer_duration_secs: Total checkpointer-level restore duration.
    checkpoint_manager_start_time: When the checkpoint-manager-level restore
      began.
    checkpoint_manager_duration_secs: Total checkpoint-manager-level restore
      duration.
  """

  step: Optional[int] = None
  event_type: Optional[str] = 'restore'
  directory: Optional[str] = None
  checkpointer_start_time: Optional[float] = None
  checkpointer_duration_secs: Optional[float] = None
  checkpoint_manager_start_time: Optional[float] = None
  checkpoint_manager_duration_secs: Optional[float] = None


@dataclasses.dataclass
class MockEmergencyRestoreStepStatistics:
  """Container mirroring emergency restore step statistics.

  Attributes:
    step: Step number of the restored checkpoint.
    event_type: Event type marker; defaults to 'emergency_restore'.
    checkpoint_manager_start_time: When the checkpoint-manager restore event
      began.
    directory: Location of the checkpoint being restored.
    is_restoring_slice: Whether this slice reads from the storage location
      (in_primary_slice=True implies is_restoring_slice=True).
    in_primary_slice: Whether this slice is the primary one, responsible for
      restoring from persistent storage.
    checkpointer_start_time: When the checkpointer-level restore began.
    checkpointer_duration_secs: Total checkpointer-level restore duration.
    broadcast_start_time: When the restore broadcast began (broadcasts done
      by SingleReplicaArrayHandler are not captured here).
    broadcast_duration_secs: Duration of the restore broadcast.
    checkpoint_manager_duration_secs: Total duration of the
      checkpoint-manager restore event.
  """

  step: Optional[int] = None
  event_type: Optional[str] = 'emergency_restore'
  checkpoint_manager_start_time: Optional[float] = None
  directory: Optional[str] = None
  is_restoring_slice: Optional[bool] = False
  in_primary_slice: Optional[bool] = False
  checkpointer_start_time: Optional[float] = None
  checkpointer_duration_secs: Optional[float] = None
  broadcast_start_time: Optional[float] = None
  broadcast_duration_secs: Optional[float] = None
  checkpoint_manager_duration_secs: Optional[float] = None
checkpoint_manager_blocking_start_time=i * 10.0, 161 | checkpoint_manager_blocking_duration_secs=default_cm_blocking_duration_secs, 162 | reached_preemption=True, 163 | preemption_received_at=i * 10.0, 164 | synchronous=True, 165 | ) 166 | ) 167 | self.checkpoint_badput_calculator.entries.append(persistent_save_entry) 168 | 169 | expected_breakdown = ( 170 | checkpoint_badput_calculator.SaveCheckpointManagerVerticalStepStats() 171 | ) 172 | expected_breakdown.total_checkpoint_manager_blocking_time = ( 173 | step_count * default_cm_blocking_duration_secs 174 | ) 175 | expected_breakdown.average_checkpoint_manager_blocking_time = ( 176 | default_cm_blocking_duration_secs 177 | ) 178 | expected_breakdown.minimum_checkpoint_manager_blocking_time = ( 179 | default_cm_blocking_duration_secs 180 | ) 181 | expected_breakdown.maximum_checkpoint_manager_blocking_time = ( 182 | default_cm_blocking_duration_secs 183 | ) 184 | expected_breakdown.standard_deviation_checkpoint_manager_blocking_time = 0 185 | expected_breakdown.total_checkpointer_blocking_time = ( 186 | step_count * default_ckptr_blocking_duration_secs 187 | ) 188 | expected_breakdown.average_checkpointer_blocking_time = ( 189 | default_ckptr_blocking_duration_secs 190 | ) 191 | expected_breakdown.minimum_checkpointer_blocking_time = ( 192 | default_ckptr_blocking_duration_secs 193 | ) 194 | expected_breakdown.maximum_checkpointer_blocking_time = ( 195 | default_ckptr_blocking_duration_secs 196 | ) 197 | expected_breakdown.standard_deviation_checkpointer_blocking_time = 0 198 | expected_breakdown.total_wait_for_prev_time = ( 199 | step_count * default_wfp_duration_secs 200 | ) 201 | expected_breakdown.average_wait_for_prev_time = default_wfp_duration_secs 202 | expected_breakdown.minimum_wait_for_prev_time = default_wfp_duration_secs 203 | expected_breakdown.maximum_wait_for_prev_time = default_wfp_duration_secs 204 | expected_breakdown.standard_deviation_wait_for_prev_time = 0 205 | 
expected_breakdown.total_get_old_steps_time = ( 206 | step_count * default_gos_duration_secs 207 | ) 208 | expected_breakdown.average_get_old_steps_time = default_gos_duration_secs 209 | expected_breakdown.minimum_get_old_steps_time = default_gos_duration_secs 210 | expected_breakdown.maximum_get_old_steps_time = default_gos_duration_secs 211 | expected_breakdown.standard_deviation_get_old_steps_time = 0 212 | 213 | cm_breakdown = ( 214 | self.checkpoint_badput_calculator.calculate_save_operation_checkpoint_manager_blocking_time( 215 | checkpoint_badput_calculator.OPERATION_TYPE_PERSISTENT 216 | ) 217 | ) 218 | for field in dataclasses.fields(cm_breakdown): 219 | value1 = getattr(cm_breakdown, field.name) 220 | value2 = getattr(expected_breakdown, field.name) 221 | if value1 != value2: 222 | raise ValueError( 223 | f"Mismatch in field '{field.name}':\n" 224 | f" Actual: {value1}\n" 225 | f" Expected: {value2}" 226 | ) 227 | 228 | def test_checkpoint_badput_calculator_local_save_operation(self): 229 | """Test for local save operation.""" 230 | step_count = 4 231 | default_cm_blocking_duration_secs = 4 232 | default_ckptr_blocking_duration_secs = 1 233 | default_gos_duration_secs = 1 234 | default_wfp_duration_secs = 2 235 | for i in range(1, step_count+1): 236 | local_save_entry = dataclasses.asdict( 237 | MockSaveStepStatistics( 238 | step=i, 239 | event_type='save', 240 | directory='local', 241 | wait_for_prev_start_time=i * 10.0, 242 | wait_for_prev_duration_secs=default_wfp_duration_secs, 243 | checkpointer_blocking_start_time=i * 10.0 + 2, 244 | checkpointer_blocking_duration_secs=default_ckptr_blocking_duration_secs, 245 | get_old_steps_start_time=i * 10.0 + 3, 246 | get_old_steps_duration_secs=default_gos_duration_secs, 247 | checkpoint_manager_blocking_start_time=i * 10.0, 248 | checkpoint_manager_blocking_duration_secs=default_cm_blocking_duration_secs, 249 | reached_preemption=True, 250 | preemption_received_at=i * 10.0, 251 | synchronous=True, 252 | ) 253 
| ) 254 | self.checkpoint_badput_calculator.entries.append(local_save_entry) 255 | 256 | expected_breakdown = ( 257 | checkpoint_badput_calculator.SaveCheckpointManagerVerticalStepStats() 258 | ) 259 | expected_breakdown.total_checkpoint_manager_blocking_time = ( 260 | step_count * default_cm_blocking_duration_secs 261 | ) 262 | expected_breakdown.average_checkpoint_manager_blocking_time = ( 263 | default_cm_blocking_duration_secs 264 | ) 265 | expected_breakdown.minimum_checkpoint_manager_blocking_time = ( 266 | default_cm_blocking_duration_secs 267 | ) 268 | expected_breakdown.maximum_checkpoint_manager_blocking_time = ( 269 | default_cm_blocking_duration_secs 270 | ) 271 | expected_breakdown.standard_deviation_checkpoint_manager_blocking_time = 0 272 | expected_breakdown.total_checkpointer_blocking_time = ( 273 | step_count * default_ckptr_blocking_duration_secs 274 | ) 275 | expected_breakdown.average_checkpointer_blocking_time = ( 276 | default_ckptr_blocking_duration_secs 277 | ) 278 | expected_breakdown.minimum_checkpointer_blocking_time = ( 279 | default_ckptr_blocking_duration_secs 280 | ) 281 | expected_breakdown.maximum_checkpointer_blocking_time = ( 282 | default_ckptr_blocking_duration_secs 283 | ) 284 | expected_breakdown.standard_deviation_checkpointer_blocking_time = 0 285 | expected_breakdown.total_wait_for_prev_time = ( 286 | step_count * default_wfp_duration_secs 287 | ) 288 | expected_breakdown.average_wait_for_prev_time = default_wfp_duration_secs 289 | expected_breakdown.minimum_wait_for_prev_time = default_wfp_duration_secs 290 | expected_breakdown.maximum_wait_for_prev_time = default_wfp_duration_secs 291 | expected_breakdown.standard_deviation_wait_for_prev_time = 0 292 | expected_breakdown.total_get_old_steps_time = ( 293 | step_count * default_gos_duration_secs 294 | ) 295 | expected_breakdown.average_get_old_steps_time = default_gos_duration_secs 296 | expected_breakdown.minimum_get_old_steps_time = default_gos_duration_secs 297 | 
expected_breakdown.maximum_get_old_steps_time = default_gos_duration_secs 298 | expected_breakdown.standard_deviation_get_old_steps_time = 0 299 | 300 | cm_breakdown = ( 301 | self.checkpoint_badput_calculator.calculate_save_operation_checkpoint_manager_blocking_time( 302 | checkpoint_badput_calculator.OPERATION_TYPE_LOCAL 303 | ) 304 | ) 305 | for field in dataclasses.fields(cm_breakdown): 306 | value1 = getattr(cm_breakdown, field.name) 307 | value2 = getattr(expected_breakdown, field.name) 308 | if value1 != value2: 309 | raise ValueError( 310 | f"Mismatch in field '{field.name}':\n" 311 | f" Actual: {value1}\n" 312 | f" Expected: {value2}" 313 | ) 314 | 315 | def test_checkpoint_badput_calculator_persistent_restore_operation(self): 316 | """Test for persistent restore operation.""" 317 | step_count = 4 318 | default_cm_duration_secs = 4 319 | default_ckptr_duration_secs = 1 320 | for i in range(1, step_count+1): 321 | persitent_save_entry = dataclasses.asdict( 322 | MockRestoreStepStatistics( 323 | step=i, 324 | event_type='restore', 325 | directory='gs://bucket/path', 326 | checkpointer_start_time=i * 10.0, 327 | checkpointer_duration_secs=default_ckptr_duration_secs, 328 | checkpoint_manager_start_time=i * 10.0 + 2, 329 | checkpoint_manager_duration_secs=default_cm_duration_secs, 330 | ) 331 | ) 332 | self.checkpoint_badput_calculator.entries.append(persitent_save_entry) 333 | 334 | expected_breakdown = ( 335 | checkpoint_badput_calculator.RestoreCheckpointManagerVerticalStepStats() 336 | ) 337 | expected_breakdown.total_checkpoint_manager_time = ( 338 | step_count * default_cm_duration_secs 339 | ) 340 | expected_breakdown.average_checkpoint_manager_time = ( 341 | default_cm_duration_secs 342 | ) 343 | expected_breakdown.minimum_checkpoint_manager_time = ( 344 | default_cm_duration_secs 345 | ) 346 | expected_breakdown.maximum_checkpoint_manager_time = ( 347 | default_cm_duration_secs 348 | ) 349 | 
expected_breakdown.standard_deviation_checkpoint_manager_time = 0 350 | expected_breakdown.total_restore_time = ( 351 | step_count * default_ckptr_duration_secs 352 | ) 353 | expected_breakdown.average_restore_time = default_ckptr_duration_secs 354 | expected_breakdown.minimum_restore_time = default_ckptr_duration_secs 355 | expected_breakdown.maximum_restore_time = default_ckptr_duration_secs 356 | expected_breakdown.standard_deviation_restore_time = 0 357 | expected_breakdown.total_broadcast_time = 0 358 | expected_breakdown.average_broadcast_time = 0 359 | expected_breakdown.minimum_broadcast_time = 0 360 | expected_breakdown.maximum_broadcast_time = 0 361 | expected_breakdown.standard_deviation_broadcast_time = 0 362 | 363 | cm_breakdown = ( 364 | self.checkpoint_badput_calculator.calculate_restore_operation_checkpoint_manager_blocking_time( 365 | checkpoint_badput_calculator.OPERATION_TYPE_PERSISTENT 366 | ) 367 | ) 368 | for field in dataclasses.fields(cm_breakdown): 369 | value1 = getattr(cm_breakdown, field.name) 370 | value2 = getattr(expected_breakdown, field.name) 371 | if value1 != value2: 372 | raise ValueError( 373 | f"Mismatch in field '{field.name}':\n" 374 | f" Actual: {value1}\n" 375 | f" Expected: {value2}" 376 | ) 377 | 378 | def test_checkpoint_badput_calculator_local_restore_operation(self): 379 | """Test for local restore operation.""" 380 | step_count = 4 381 | default_cm_duration_secs = 4 382 | default_ckptr_duration_secs = 2 383 | default_broadcast_duration_secs = 2 384 | for i in range(1, step_count+1): 385 | local_save_entry = dataclasses.asdict( 386 | MockEmergencyRestoreStepStatistics( 387 | step=i, 388 | event_type='emergency_restore', 389 | directory='local', 390 | checkpointer_start_time=i * 10.0, 391 | checkpointer_duration_secs=default_ckptr_duration_secs, 392 | checkpoint_manager_start_time=i * 10.0 + 2, 393 | checkpoint_manager_duration_secs=default_cm_duration_secs, 394 | broadcast_start_time=i * 10.0 + 3, 395 | 
broadcast_duration_secs=default_broadcast_duration_secs, 396 | ) 397 | ) 398 | self.checkpoint_badput_calculator.entries.append(local_save_entry) 399 | 400 | expected_breakdown = ( 401 | checkpoint_badput_calculator.RestoreCheckpointManagerVerticalStepStats() 402 | ) 403 | expected_breakdown.total_checkpoint_manager_time = ( 404 | default_cm_duration_secs * step_count 405 | ) 406 | expected_breakdown.average_checkpoint_manager_time = ( 407 | default_cm_duration_secs 408 | ) 409 | expected_breakdown.minimum_checkpoint_manager_time = ( 410 | default_cm_duration_secs 411 | ) 412 | expected_breakdown.maximum_checkpoint_manager_time = ( 413 | default_cm_duration_secs 414 | ) 415 | expected_breakdown.standard_deviation_checkpoint_manager_time = 0 416 | expected_breakdown.total_restore_time = ( 417 | step_count * default_ckptr_duration_secs 418 | ) 419 | expected_breakdown.average_restore_time = default_ckptr_duration_secs 420 | expected_breakdown.minimum_restore_time = default_ckptr_duration_secs 421 | expected_breakdown.maximum_restore_time = default_ckptr_duration_secs 422 | expected_breakdown.standard_deviation_restore_time = 0 423 | expected_breakdown.total_broadcast_time = ( 424 | step_count * default_broadcast_duration_secs 425 | ) 426 | expected_breakdown.average_broadcast_time = default_broadcast_duration_secs 427 | expected_breakdown.minimum_broadcast_time = default_broadcast_duration_secs 428 | expected_breakdown.maximum_broadcast_time = default_broadcast_duration_secs 429 | expected_breakdown.standard_deviation_broadcast_time = 0 430 | 431 | cm_breakdown = ( 432 | self.checkpoint_badput_calculator.calculate_restore_operation_checkpoint_manager_blocking_time( 433 | checkpoint_badput_calculator.OPERATION_TYPE_LOCAL 434 | ) 435 | ) 436 | for field in dataclasses.fields(cm_breakdown): 437 | value1 = getattr(cm_breakdown, field.name) 438 | value2 = getattr(expected_breakdown, field.name) 439 | if value1 != value2: 440 | raise ValueError( 441 | f"Mismatch in field 
'{field.name}':\n" 442 | f" Actual: {value1}\n" 443 | f" Expected: {value2}" 444 | ) 445 | if __name__ == '__main__': 446 | absltest.main() 447 | -------------------------------------------------------------------------------- /ml_goodput_measurement/tests/gcp_metrics_test.py: -------------------------------------------------------------------------------- 1 | """Tests for GCP metrics.""" 2 | 3 | from unittest import mock 4 | 5 | from absl.testing import absltest 6 | from cloud_goodput.ml_goodput_measurement.src import gcp_metrics 7 | from google.api_core import exceptions 8 | from google.cloud import monitoring_v3 9 | 10 | 11 | ValueType = gcp_metrics.ValueType 12 | GCPMetrics = gcp_metrics.GCPMetrics 13 | patch = mock.patch 14 | GoogleAPIError = exceptions.GoogleAPIError 15 | 16 | 17 | class GCPMetricsTest(absltest.TestCase): 18 | 19 | @patch("google.cloud.monitoring_v3.MetricServiceClient") 20 | def setUp(self, mock_client): 21 | super().setUp() 22 | self.mock_client = mock_client.return_value 23 | self.project_id = "test-project" 24 | self.metrics_sender = GCPMetrics(self.project_id) 25 | 26 | def test_create_time_series(self): 27 | metric_type = "compute.googleapis.com/workload/goodput_time" 28 | value = 123.45 29 | value_type = ValueType.DOUBLE 30 | metric_labels = { 31 | "goodput_source": "TOTAL", 32 | "accelerator_type": "tpu-v5p", 33 | } 34 | resource_type = "compute.googleapis.com/Workload" 35 | resource_labels = { 36 | "location": "us-central1", 37 | "workload_id": "test-workload", 38 | "replica_id": "0", 39 | } 40 | seconds = 1677347200 41 | nanos = 123456789 42 | 43 | time_series = self.metrics_sender.create_time_series( 44 | metric_type, 45 | value, 46 | value_type, 47 | metric_labels, 48 | resource_type, 49 | resource_labels, 50 | seconds, 51 | nanos, 52 | ) 53 | 54 | # Assertions to check if the TimeSeries object is created correctly 55 | self.assertIsInstance(time_series, monitoring_v3.TimeSeries) 56 | self.assertEqual(time_series.metric.type, 
metric_type) 57 | self.assertEqual(time_series.resource.type, resource_type) 58 | self.assertEqual(time_series.resource.labels, resource_labels) 59 | self.assertEqual(time_series.metric.labels, metric_labels) 60 | 61 | # Correctly check the value based on value_type 62 | if value_type == ValueType.BOOL: 63 | self.assertEqual(time_series.points[0].value.bool_value, value) 64 | elif value_type == ValueType.INT: 65 | self.assertEqual(time_series.points[0].value.int64_value, value) 66 | elif value_type == ValueType.DOUBLE: 67 | self.assertEqual(time_series.points[0].value.double_value, value) 68 | elif value_type == ValueType.STRING: 69 | self.assertEqual(time_series.points[0].value.string_value, value) 70 | elif value_type == ValueType.DISTRIBUTION: 71 | self.assertEqual( 72 | time_series.points[0].value.distribution_value, value 73 | ) 74 | 75 | @patch("time.time") 76 | def test_send_metrics(self, mock_time): 77 | # Set a fixed return value for the mocked time.time() 78 | mock_time.return_value = 1677347200.5 79 | 80 | metrics_to_send = [ 81 | { 82 | "metric_type": "compute.googleapis.com/workload/goodput_time", 83 | "value": 42.0, 84 | "value_type": ValueType.DOUBLE, 85 | "resource_type": "test_resource", 86 | "resource_labels": {"loc": "us"}, 87 | }, 88 | { 89 | "metric_type": "compute.googleapis.com/workload/badput_time", 90 | "value": 10, 91 | "value_type": ValueType.INT, 92 | "metric_labels": {"source": "test2"}, 93 | "resource_type": "test_resource", 94 | "resource_labels": {"loc": "eu"}, 95 | }, 96 | ] 97 | 98 | self.metrics_sender.send_metrics(metrics_to_send) 99 | 100 | # Verify that create_time_series was called with the correct arguments 101 | expected_name = f"projects/{self.project_id}" 102 | expected_calls = [] 103 | for metric in metrics_to_send: 104 | metric_labels = metric.get("metric_labels", {}) 105 | series = self.metrics_sender.create_time_series( 106 | metric["metric_type"], 107 | metric["value"], 108 | metric["value_type"], 109 | metric_labels, 
110 | metric["resource_type"], 111 | metric["resource_labels"], 112 | 1677347200, # seconds 113 | 500000000, # nanos 114 | ) 115 | expected_calls.append(series) 116 | 117 | self.mock_client.create_time_series.assert_called_once() 118 | _, kwargs = self.mock_client.create_time_series.call_args 119 | self.assertEqual(kwargs["name"], expected_name) 120 | # Check time series 121 | actual_series = kwargs["time_series"] 122 | self.assertEqual(len(actual_series), len(expected_calls)) 123 | for actual, expected in zip(actual_series, expected_calls): 124 | self.assertEqual(actual.metric.type, expected.metric.type) 125 | self.assertEqual(actual.resource.type, expected.resource.type) 126 | self.assertEqual(actual.resource.labels, expected.resource.labels) 127 | self.assertEqual(actual.metric.labels, expected.metric.labels) 128 | 129 | @patch("cloud_goodput.ml_goodput_measurement.src.gcp_metrics.logger.error") 130 | def test_send_metrics_failure(self, mock_logging_error): 131 | 132 | self.mock_client.create_time_series.side_effect = GoogleAPIError( 133 | "Test Error" 134 | ) 135 | 136 | metrics_to_send = [ 137 | { 138 | "metric_type": "compute.googleapis.com/workload/goodput_time", 139 | "value": 42.0, 140 | "value_type": ValueType.DOUBLE, 141 | "resource_type": "test_resource", 142 | "resource_labels": {"loc": "us"}, 143 | } 144 | ] 145 | 146 | self.metrics_sender.send_metrics(metrics_to_send) 147 | mock_logging_error.assert_called_once() 148 | 149 | if __name__ == "__main__": 150 | absltest.main() 151 | -------------------------------------------------------------------------------- /ml_goodput_measurement/tests/goodput_cache_test.py: -------------------------------------------------------------------------------- 1 | """Tests to unit test GoodputCache class.""" 2 | 3 | import datetime 4 | from unittest import mock 5 | 6 | from cloud_goodput.ml_goodput_measurement.src import goodput_cache 7 | from cloud_goodput.ml_goodput_measurement.src import goodput_utils 8 | from 
cloud_goodput.ml_goodput_measurement.src.goodput_utils import BadputType, GoodputInfo 9 | 10 | from google3.testing.pybase import googletest 11 | 12 | 13 | class GoodputCacheTest(googletest.TestCase): 14 | 15 | def setUp(self): 16 | super().setUp() 17 | self.goodput_cache = goodput_cache.GoodputCache() 18 | 19 | def test_update_cached_entries(self): 20 | mock_entries = [ 21 | {'time': 1, 'step': 1}, 22 | {'time': 2, 'step': 2}, 23 | {'time': 3, 'step': 3}, 24 | ] 25 | self.goodput_cache.update_cached_entries(mock_entries) 26 | self.assertFalse(self.goodput_cache.is_cache_empty()) 27 | self.assertEqual(self.goodput_cache.get_cached_entries(), mock_entries) 28 | 29 | def test_update_goodput_info(self): 30 | goodput_info = GoodputInfo( 31 | total_productive_time=100, 32 | total_elapsed_time=200, 33 | total_unproductive_time={ 34 | BadputType.TPU_INITIALIZATION: 10, 35 | BadputType.TRAINING_PREP: 10, 36 | BadputType.DATA_LOADING_SYNC: 30, 37 | BadputType.PROGRAM_STARTUP: 10, 38 | BadputType.UNPRODUCTIVE_CHECKPOINT_SAVE_TIME: 20, 39 | BadputType.UNPRODUCTIVE_CHECKPOINT_RESTORE_TIME: 10, 40 | BadputType.WASTED_PROGRESS_FROM_DISRUPTION: 10, 41 | BadputType.OTHER: 10, 42 | }, 43 | max_productive_step=3, 44 | last_recorded_step=3, 45 | number_of_disruptions=1, 46 | ) 47 | self.goodput_cache.update_goodput_info(goodput_info) 48 | self.assertEqual(self.goodput_cache._goodput_info, goodput_info) 49 | 50 | def test_clear_cache(self): 51 | mock_entries = [ 52 | {'time': 1, 'step': 1}, 53 | {'time': 2, 'step': 2}, 54 | {'time': 3, 'step': 3}, 55 | ] 56 | self.goodput_cache.update_cached_entries(mock_entries) 57 | self.goodput_cache.update_goodput_info( 58 | GoodputInfo( 59 | total_productive_time=100, 60 | total_elapsed_time=200, 61 | total_unproductive_time={ 62 | BadputType.TPU_INITIALIZATION: 10, 63 | BadputType.TRAINING_PREP: 10, 64 | BadputType.DATA_LOADING_SYNC: 30, 65 | BadputType.PROGRAM_STARTUP: 10, 66 | BadputType.UNPRODUCTIVE_CHECKPOINT_SAVE_TIME: 20, 67 | 
BadputType.UNPRODUCTIVE_CHECKPOINT_RESTORE_TIME: 10, 68 | BadputType.WASTED_PROGRESS_FROM_DISRUPTION: 10, 69 | BadputType.OTHER: 10, 70 | }, 71 | max_productive_step=3, 72 | last_recorded_step=3, 73 | number_of_disruptions=1, 74 | ) 75 | ) 76 | self.goodput_cache.clear_cache() 77 | self.assertEqual(self.goodput_cache.get_cached_entries(), []) 78 | self.assertIsNone(self.goodput_cache._goodput_info) 79 | self.assertIsNone(self.goodput_cache._last_entry_time) 80 | 81 | def test_is_cache_empty(self): 82 | self.assertTrue(self.goodput_cache.is_cache_empty()) 83 | self.goodput_cache.update_cached_entries([ 84 | {'time': 1, 'step': 1}, 85 | {'time': 2, 'step': 2}, 86 | {'time': 3, 'step': 3}, 87 | ]) 88 | self.assertFalse(self.goodput_cache.is_cache_empty()) 89 | 90 | def test_get_last_entry_time(self): 91 | self.assertIsNone(self.goodput_cache._last_entry_time) 92 | self.goodput_cache.update_cached_entries([ 93 | {'time': 1, 'step': 1}, 94 | {'time': 2, 'step': 2}, 95 | {'time': 3, 'step': 3}, 96 | ]) 97 | self.assertFalse(self.goodput_cache.is_cache_empty()) 98 | self.assertEqual( 99 | self.goodput_cache.get_last_entry_time(), 100 | goodput_utils.EntryTime('time', 3), 101 | ) 102 | 103 | def test_get_step_info(self): 104 | step_info = goodput_utils.StepInfo( 105 | step_deviations={1: 1.0, 2: 2.0}, 106 | ideal_step_time=1.0, 107 | ) 108 | self.goodput_cache.update_step_info(step_info) 109 | self.assertEqual(self.goodput_cache._step_info, step_info) 110 | 111 | def test_update_job_start_time(self): 112 | self.assertIsNone(self.goodput_cache._job_start_time) 113 | self.goodput_cache.update_cached_entries([ 114 | {'step_start_time': 2, 'step': 1}, 115 | {'step_start_time': 3, 'step': 2}, 116 | {'job_end_time': 4}, 117 | ]) 118 | self.assertIsNone(self.goodput_cache._job_start_time) 119 | self.goodput_cache.update_cached_entries([ 120 | {'job_start_time': 1}, 121 | {'job_start_time': 9}, 122 | {'step_start_time': 2, 'step': 1}, 123 | {'step_start_time': 3, 'step': 2}, 124 | 
{'job_end_time': 4}, 125 | ]) 126 | self.assertEqual( 127 | self.goodput_cache._job_start_time, 128 | datetime.datetime.fromtimestamp(1, tz=datetime.timezone.utc), 129 | ) 130 | 131 | def test_update_job_end_time(self): 132 | self.assertIsNone(self.goodput_cache._job_end_time) 133 | self.goodput_cache.update_cached_entries([ 134 | {'job_end_time': 1}, 135 | {'job_end_time': 2}, 136 | {'job_end_time': 3}, 137 | ]) 138 | self.assertEqual( 139 | self.goodput_cache._job_end_time, 140 | datetime.datetime.fromtimestamp(3, tz=datetime.timezone.utc), 141 | ) 142 | 143 | 144 | if __name__ == '__main__': 145 | googletest.main() 146 | -------------------------------------------------------------------------------- /ml_goodput_measurement/tests/monitoring_test.py: -------------------------------------------------------------------------------- 1 | """Tests to validate the monitoring module. 2 | 3 | This module tests the GoodputMonitor class and its functionality, specifically 4 | the uploading of step deviation, goodput and badput data to Tensorboard. 
5 | """ 6 | 7 | from unittest import mock 8 | 9 | from absl.testing import absltest 10 | from cloud_goodput.ml_goodput_measurement.src import gcp_metrics 11 | from cloud_goodput.ml_goodput_measurement.src import goodput_utils 12 | from cloud_goodput.ml_goodput_measurement.src import monitoring 13 | 14 | from google.cloud import monitoring_v3 15 | 16 | BadputType = goodput_utils.BadputType 17 | GCPOptions = goodput_utils.GCPOptions 18 | GoodputMonitor = monitoring.GoodputMonitor 19 | GoodputType = goodput_utils.GoodputType 20 | IntervalMetricType = goodput_utils.IntervalMetricType 21 | MagicMock = mock.MagicMock 22 | MetricType = goodput_utils.MetricType 23 | ValueType = gcp_metrics.ValueType 24 | 25 | patch = mock.patch 26 | _TEST_UPLOAD_INTERVAL = 1 27 | 28 | 29 | class GoodputMonitorTests(absltest.TestCase): 30 | """Tests for the GoodputMonitor class.""" 31 | 32 | def setUp(self): 33 | super().setUp() 34 | self.job_name = 'test-run' 35 | self.logger_name = 'test-logger' 36 | self.tensorboard_dir = 'test-dir' 37 | 38 | def _create_timeseries( 39 | self, metric_type: str, labels: dict, value: float 40 | ) -> monitoring_v3.TimeSeries: 41 | ts = monitoring_v3.TimeSeries() 42 | ts.metric.type = metric_type 43 | ts.metric.labels.update(labels) 44 | ts.resource.type = 'compute.googleapis.com/Workload' 45 | ts.resource.labels.update({ 46 | 'location': 'test-location', 47 | 'workload_id': 'test-run', 48 | 'replica_id': 'test-replica-id', 49 | }) 50 | ts.points.append( 51 | monitoring_v3.Point( 52 | value=monitoring_v3.TypedValue(double_value=value), 53 | ) 54 | ) 55 | return ts 56 | 57 | def _compare_calls_ignore_time_series( 58 | self, expected_call, actual_call 59 | ) -> bool: 60 | if ( 61 | expected_call.args != actual_call.args 62 | or expected_call.kwargs.keys() != actual_call.kwargs.keys() 63 | ): 64 | return False 65 | 66 | for key, expected_value in expected_call.kwargs.items(): 67 | actual_value = actual_call.kwargs[key] 68 | if key == 'time_series': 69 | 
continue 70 | if expected_value != actual_value: 71 | return False 72 | 73 | return True 74 | 75 | def _setup_mock_goodput_monitor( 76 | self, mock_logging_client, mock_summary_writer, mock_metric_service_client 77 | ) -> GoodputMonitor: 78 | mock_client = MagicMock() 79 | mock_metric_service_client.return_value = mock_client 80 | mock_logging_client.return_value = MagicMock() 81 | mock_summary_writer.return_value = MagicMock() 82 | 83 | gcp_options = GCPOptions( 84 | enable_gcp_goodput_metrics=True, 85 | project_id='test-project', 86 | location='test-location', 87 | acc_type='test-acc-type', 88 | replica_id='test-replica-id', 89 | ) 90 | 91 | return GoodputMonitor( 92 | job_name='test-run', 93 | logger_name='test-logger', 94 | tensorboard_dir='/tmp', 95 | upload_interval=1, 96 | monitoring_enabled=True, 97 | gcp_options=gcp_options, 98 | ) 99 | 100 | @patch('tensorboardX.writer.SummaryWriter') 101 | @patch('google.cloud.logging.Client') 102 | def test_goodput_monitor_init(self, mock_logger_client, mock_summary_writer): 103 | mock_summary_writer.return_value = MagicMock() 104 | mock_logger_client.return_value = MagicMock() 105 | goodput_monitor = GoodputMonitor( 106 | self.job_name, 107 | self.logger_name, 108 | self.tensorboard_dir, 109 | upload_interval=_TEST_UPLOAD_INTERVAL, 110 | monitoring_enabled=True, 111 | ) 112 | # Objects should be initialized correctly. 113 | self.assertIsNotNone(goodput_monitor) 114 | self.assertIs(goodput_monitor._writer, mock_summary_writer.return_value) 115 | self.assertIsNotNone(goodput_monitor._goodput_calculator) 116 | 117 | # Thread events should be initialized correctly. 
118 | self.assertIsNotNone(goodput_monitor._step_deviation_termination_event) 119 | self.assertFalse(goodput_monitor._step_deviation_termination_event.is_set()) 120 | self.assertFalse(goodput_monitor._step_deviation_uploader_thread_running) 121 | self.assertIsNotNone(goodput_monitor._termination_event) 122 | self.assertFalse(goodput_monitor._termination_event.is_set()) 123 | self.assertFalse(goodput_monitor._goodput_uploader_thread_running) 124 | 125 | @patch( 126 | 'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._write_goodput_to_tensorboard' 127 | ) 128 | @patch('tensorboardX.writer.SummaryWriter') 129 | @patch('google.cloud.logging.Client') 130 | async def test_goodput_monitor_start_goodput_uploader_success( 131 | self, mock_logger_client, mock_summary_writer, mock_goodput_to_tensorboard 132 | ): 133 | mock_summary_writer.return_value = MagicMock() 134 | mock_goodput_to_tensorboard.return_value = MagicMock() 135 | mock_logger_client.return_value = MagicMock() 136 | goodput_monitor = monitoring.GoodputMonitor( 137 | self.job_name, 138 | self.logger_name, 139 | self.tensorboard_dir, 140 | upload_interval=_TEST_UPLOAD_INTERVAL, 141 | monitoring_enabled=True, 142 | ) 143 | goodput_monitor.start_goodput_uploader() 144 | self.assertTrue(goodput_monitor._uploader_thread_running) 145 | self.assertIsNotNone(goodput_monitor._goodput_upload_thread) 146 | self.assertFalse(goodput_monitor._termination_event.is_set()) 147 | mock_goodput_to_tensorboard.assert_called_once() 148 | mock_summary_writer.return_value.add_scalar.assert_called_once() 149 | goodput_monitor.stop_goodput_uploader() 150 | self.assertFalse(goodput_monitor._uploader_thread_running) 151 | self.assertIsNone(goodput_monitor._goodput_upload_thread) 152 | self.assertTrue(goodput_monitor._termination_event.is_set()) 153 | 154 | @patch( 155 | 'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._write_goodput_to_tensorboard' 156 | ) 157 | 
@patch('tensorboardX.writer.SummaryWriter') 158 | @patch('google.cloud.logging.Client') 159 | async def test_goodput_monitor_start_goodput_uploader_failure( 160 | self, mock_logger_client, mock_summary_writer, mock_goodput_to_tensorboard 161 | ): 162 | mock_logger_client.return_value = MagicMock() 163 | mock_summary_writer.return_value = MagicMock() 164 | mock_goodput_to_tensorboard.side_effect = ValueError('Test Error') 165 | goodput_monitor = monitoring.GoodputMonitor( 166 | self.job_name, 167 | self.logger_name, 168 | self.tensorboard_dir, 169 | upload_interval=_TEST_UPLOAD_INTERVAL, 170 | monitoring_enabled=True, 171 | ) 172 | goodput_monitor.start_goodput_uploader() 173 | self.assertTrue(goodput_monitor._uploader_thread_running) 174 | self.assertIsNotNone(goodput_monitor._goodput_upload_thread) 175 | self.assertFalse(goodput_monitor._termination_event.is_set()) 176 | mock_goodput_to_tensorboard.assert_called_once() 177 | with self.assertRaisesRegex(ValueError, 'Test Error'): 178 | goodput_monitor._query_and_upload_goodput() 179 | mock_summary_writer.return_value.add_scalar.assert_not_called() 180 | goodput_monitor.stop_goodput_uploader() 181 | self.assertFalse(goodput_monitor._uploader_thread_running) 182 | self.assertIsNone(goodput_monitor._goodput_upload_thread) 183 | self.assertTrue(goodput_monitor._termination_event.is_set()) 184 | 185 | @patch( 186 | 'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._write_badput_to_tensorboard' 187 | ) 188 | @patch('tensorboardX.writer.SummaryWriter') 189 | @patch('google.cloud.logging.Client') 190 | async def test_goodput_monitor_start_badput_uploader_success( 191 | self, mock_logger_client, mock_summary_writer, mock_badput_to_tensorboard 192 | ): 193 | mock_summary_writer.return_value = MagicMock() 194 | mock_badput_to_tensorboard.return_value = MagicMock() 195 | mock_logger_client.return_value = MagicMock() 196 | goodput_monitor = monitoring.GoodputMonitor( 197 | self.job_name, 198 | 
self.logger_name, 199 | self.tensorboard_dir, 200 | upload_interval=_TEST_UPLOAD_INTERVAL, 201 | monitoring_enabled=True, 202 | include_badput_breakdown=True, 203 | ) 204 | 205 | goodput_monitor.start_goodput_uploader() 206 | self.assertTrue(goodput_monitor._uploader_thread_running) 207 | self.assertIsNotNone(goodput_monitor._goodput_upload_thread) 208 | self.assertFalse(goodput_monitor._termination_event.is_set()) 209 | self.assertTrue(goodput_monitor._include_badput_breakdown) 210 | 211 | mock_badput_to_tensorboard.assert_called_once() 212 | mock_summary_writer.return_value.add_scalar.assert_called_once() 213 | 214 | goodput_monitor.stop_goodput_uploader() 215 | self.assertFalse(goodput_monitor._uploader_thread_running) 216 | self.assertIsNone(goodput_monitor._goodput_upload_thread) 217 | self.assertTrue(goodput_monitor._termination_event.is_set()) 218 | 219 | @patch( 220 | 'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._write_step_deviation_to_tensorboard' 221 | ) 222 | @patch('tensorboardX.writer.SummaryWriter') 223 | @patch('google.cloud.logging.Client') 224 | async def test_goodput_monitor_start_step_deviation_uploader_success( 225 | self, 226 | mock_logger_client, 227 | mock_summary_writer, 228 | mock_step_deviation_to_tensorboard, 229 | ): 230 | mock_logger_client.return_value = MagicMock() 231 | mock_summary_writer.return_value = MagicMock() 232 | mock_step_deviation_to_tensorboard.return_value = MagicMock() 233 | goodput_monitor = monitoring.GoodputMonitor( 234 | self.job_name, 235 | self.logger_name, 236 | self.tensorboard_dir, 237 | upload_interval=_TEST_UPLOAD_INTERVAL, 238 | monitoring_enabled=True, 239 | include_step_deviation=True, 240 | ) 241 | goodput_monitor.start_step_deviation_uploader() 242 | self.assertTrue(goodput_monitor._step_deviation_uploader_thread_running) 243 | self.assertIsNotNone(goodput_monitor._step_deviation_upload_thread) 244 | self.assertFalse(goodput_monitor._step_deviation_termination_event.is_set()) 245 
| mock_step_deviation_to_tensorboard.assert_called_once() 246 | mock_summary_writer.return_value.add_scalar.assert_called_once() 247 | goodput_monitor.stop_step_deviation_uploader() 248 | self.assertFalse(goodput_monitor._step_deviation_uploader_thread_running) 249 | self.assertIsNone(goodput_monitor._step_deviation_upload_thread) 250 | self.assertTrue(goodput_monitor._step_deviation_termination_event.is_set()) 251 | 252 | @patch( 253 | 'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._write_step_deviation_to_tensorboard' 254 | ) 255 | @patch('tensorboardX.writer.SummaryWriter') 256 | @patch('google.cloud.logging.Client') 257 | async def test_goodput_monitor_start_step_deviation_uploader_failure( 258 | self, 259 | mock_logger_client, 260 | mock_summary_writer, 261 | mock_query_and_upload_step_deviation, 262 | ): 263 | mock_logger_client.return_value = MagicMock() 264 | mock_summary_writer.return_value = MagicMock() 265 | mock_query_and_upload_step_deviation.side_effect = ValueError('Test Error') 266 | goodput_monitor = monitoring.GoodputMonitor( 267 | self.job_name, 268 | self.logger_name, 269 | self.tensorboard_dir, 270 | upload_interval=_TEST_UPLOAD_INTERVAL, 271 | monitoring_enabled=True, 272 | include_step_deviation=True, 273 | ) 274 | goodput_monitor.start_step_deviation_uploader() 275 | self.assertTrue(goodput_monitor._step_deviation_uploader_thread_running) 276 | self.assertIsNotNone(goodput_monitor._step_deviation_upload_thread) 277 | self.assertFalse(goodput_monitor._step_deviation_termination_event.is_set()) 278 | mock_query_and_upload_step_deviation.assert_called_once() 279 | with self.assertRaisesRegex(ValueError, 'Test Error'): 280 | goodput_monitor._query_and_upload_step_deviation() 281 | mock_summary_writer.return_value.add_scalar.assert_not_called() 282 | goodput_monitor.stop_step_deviation_uploader() 283 | self.assertFalse(goodput_monitor._step_deviation_uploader_thread_running) 284 | 
self.assertIsNone(goodput_monitor._step_deviation_upload_thread) 285 | self.assertTrue(goodput_monitor._step_deviation_termination_event.is_set()) 286 | 287 | @patch('google.cloud.monitoring_v3.MetricServiceClient') 288 | @patch('tensorboardX.writer.SummaryWriter') 289 | @patch('google.cloud.logging.Client') 290 | def test_send_goodput_metrics_to_gcp_success( 291 | self, 292 | mock_logging_client, 293 | mock_summary_writer, 294 | mock_metric_service_client, 295 | ): 296 | mock_client = MagicMock() 297 | mock_metric_service_client.return_value = mock_client 298 | mock_logging_client.return_value = MagicMock() 299 | mock_summary_writer.return_value = MagicMock() 300 | 301 | gcp_options = GCPOptions( 302 | enable_gcp_goodput_metrics=True, 303 | project_id='test-project', 304 | location='test-location', 305 | acc_type='test-acc-type', 306 | replica_id='test-replica-id', 307 | ) 308 | 309 | goodput_monitor = GoodputMonitor( 310 | self.job_name, 311 | self.logger_name, 312 | self.tensorboard_dir, 313 | upload_interval=_TEST_UPLOAD_INTERVAL, 314 | monitoring_enabled=True, 315 | gcp_options=gcp_options, 316 | ) 317 | 318 | # Mock the get_job_goodput_details to return test data 319 | goodput_monitor._goodput_calculator.get_job_goodput_details = MagicMock( 320 | return_value={ 321 | MetricType.GOODPUT_TIME.value: { 322 | GoodputType.TOTAL: 10.0, 323 | }, 324 | MetricType.BADPUT_TIME.value: { 325 | BadputType.TPU_INITIALIZATION: 2.0, 326 | BadputType.DATA_LOADING_SYNC: 1.0, 327 | }, 328 | MetricType.DISRUPTION_COUNT.value: 0, 329 | MetricType.MAX_PRODUCTIVE_STEP.value: 2, 330 | MetricType.TOTAL_ELAPSED_TIME.value: 20.0, 331 | MetricType.STEP_TIME_DEVIATION.value: { 332 | 0: 1.0, 333 | 1: 1.0, 334 | 2: 1.0, 335 | }, 336 | MetricType.IDEAL_STEP_TIME.value: 1.0, 337 | } 338 | ) 339 | 340 | goodput_monitor._upload_goodput_metrics_to_gcm( 341 | goodput_monitor._goodput_calculator.get_job_goodput_details() 342 | ) 343 | 344 | expected_calls = [ 345 | mock.call.create_time_series( 
346 | name='projects/test-project', 347 | time_series=[ 348 | self._create_timeseries( 349 | 'compute.googleapis.com/workload/goodput_time', 350 | { 351 | 'goodput_source': 'TOTAL', 352 | 'accelerator_type': 'test-acc-type', 353 | }, 354 | 10.0, 355 | ) 356 | ], 357 | ), 358 | mock.call.create_time_series( 359 | name='projects/test-project', 360 | time_series=[ 361 | self._create_timeseries( 362 | 'compute.googleapis.com/workload/badput_time', 363 | { 364 | 'badput_source': 'TPU_INITIALIZATION', 365 | 'accelerator_type': 'test-acc-type', 366 | }, 367 | 2.0, 368 | ) 369 | ], 370 | ), 371 | mock.call.create_time_series( 372 | name='projects/test-project', 373 | time_series=[ 374 | self._create_timeseries( 375 | 'compute.googleapis.com/workload/badput_time', 376 | { 377 | 'badput_source': 'DATA_LOADING_SYNC', 378 | 'accelerator_type': 'test-acc-type', 379 | }, 380 | 1.0, 381 | ) 382 | ], 383 | ), 384 | mock.call.create_time_series( 385 | name='projects/test-project', 386 | time_series=[ 387 | self._create_timeseries( 388 | 'compute.googleapis.com/workload/disruptions', 389 | { 390 | 'accelerator_type': 'test-acc-type', 391 | 'window_type': 'CUMULATIVE', 392 | }, 393 | 0, 394 | ) 395 | ], 396 | ), 397 | mock.call.create_time_series( 398 | name='projects/test-project', 399 | time_series=[ 400 | self._create_timeseries( 401 | 'compute.googleapis.com/workload/max_productive_steps', 402 | { 403 | 'accelerator_type': 'test-acc-type', 404 | }, 405 | 2, 406 | ) 407 | ], 408 | ), 409 | mock.call.create_time_series( 410 | name='projects/test-project', 411 | time_series=[ 412 | self._create_timeseries( 413 | 'compute.googleapis.com/workload/total_elapsed_time', 414 | { 415 | 'accelerator_type': 'test-acc-type', 416 | 'window_type': 'CUMULATIVE', 417 | }, 418 | 20.0, 419 | ) 420 | ], 421 | ), 422 | mock.call.create_time_series( 423 | name='projects/test-project', 424 | time_series=[ 425 | self._create_timeseries( 426 | 'compute.googleapis.com/workload/step_time_deviation', 427 | 
{ 428 | 'accelerator_type': 'test-acc-type', 429 | }, 430 | 1.0, 431 | ) 432 | ], 433 | ), 434 | mock.call.create_time_series( 435 | name='projects/test-project', 436 | time_series=[ 437 | self._create_timeseries( 438 | 'compute.googleapis.com/workload/performance', 439 | { 440 | 'accelerator_type': 'test-acc-type', 441 | }, 442 | 1.0, 443 | ) 444 | ], 445 | ), 446 | ] 447 | 448 | actual_calls = mock_client.create_time_series.call_args_list 449 | 450 | # Verify each call individually 451 | for expected_call in expected_calls: 452 | self.assertTrue( 453 | any( 454 | self._compare_calls_ignore_time_series(expected_call, actual) 455 | for actual in actual_calls 456 | ), 457 | f'Expected call not found: {expected_call}', 458 | ) 459 | 460 | @patch('google.cloud.monitoring_v3.MetricServiceClient') 461 | @patch('tensorboardX.writer.SummaryWriter') 462 | @patch('google.cloud.logging.Client') 463 | def test_send_goodput_metrics_to_gcp_exception( 464 | self, 465 | mock_logging_client, 466 | mock_summary_writer, 467 | mock_metric_service_client, 468 | ): 469 | mock_client = MagicMock() 470 | mock_client.create_time_series.side_effect = Exception('Test Exception') 471 | mock_metric_service_client.return_value = mock_client 472 | mock_logging_client.return_value = MagicMock() 473 | mock_summary_writer.return_value = MagicMock() 474 | 475 | gcp_options = GCPOptions( 476 | enable_gcp_goodput_metrics=True, 477 | project_id='test-project', 478 | location='test-location', 479 | acc_type='test-acc-type', 480 | replica_id='test-replica-id', 481 | ) 482 | 483 | goodput_monitor = GoodputMonitor( 484 | self.job_name, 485 | self.logger_name, 486 | self.tensorboard_dir, 487 | upload_interval=_TEST_UPLOAD_INTERVAL, 488 | monitoring_enabled=True, 489 | gcp_options=gcp_options, 490 | ) 491 | 492 | # Mock the get_job_goodput_details to return test data 493 | goodput_monitor._goodput_calculator.get_job_goodput_details = MagicMock( 494 | return_value={ 495 | MetricType.GOODPUT_TIME.value: { 496 
| GoodputType.TOTAL: 10.0, 497 | }, 498 | MetricType.BADPUT_TIME.value: { 499 | BadputType.DATA_LOADING_SYNC: 2.0, 500 | }, 501 | MetricType.DISRUPTION_COUNT.value: 0, 502 | MetricType.MAX_PRODUCTIVE_STEP.value: 2, 503 | MetricType.TOTAL_ELAPSED_TIME.value: 20.0, 504 | MetricType.STEP_TIME_DEVIATION.value: { 505 | 0: 1.0, 506 | 1: 1.0, 507 | 2: 1.0, 508 | }, 509 | MetricType.IDEAL_STEP_TIME.value: 1.0, 510 | } 511 | ) 512 | 513 | goodput_monitor._upload_goodput_metrics_to_gcm( 514 | goodput_monitor._goodput_calculator.get_job_goodput_details() 515 | ) 516 | 517 | # Verify that create_time_series was called, even if it raised an exception 518 | mock_client.create_time_series.assert_called_once() 519 | 520 | @patch('google.cloud.monitoring_v3.MetricServiceClient') 521 | @patch('tensorboardX.writer.SummaryWriter') 522 | @patch('google.cloud.logging.Client') 523 | def test_send_goodput_metrics_to_gcp_exclusion( 524 | self, 525 | mock_logging_client, 526 | mock_summary_writer, 527 | mock_metric_service_client 528 | ): 529 | mock_client = MagicMock() 530 | mock_metric_service_client.return_value = mock_client 531 | mock_logging_client.return_value = MagicMock() 532 | mock_summary_writer.return_value = MagicMock() 533 | 534 | gcp_options = GCPOptions( 535 | enable_gcp_goodput_metrics=True, 536 | project_id='test-project', 537 | location='test-location', 538 | acc_type='test-acc-type', 539 | replica_id='test-replica-id', 540 | ) 541 | 542 | goodput_monitor = GoodputMonitor( 543 | self.job_name, 544 | self.logger_name, 545 | self.tensorboard_dir, 546 | upload_interval=_TEST_UPLOAD_INTERVAL, 547 | monitoring_enabled=True, 548 | gcp_options=gcp_options, 549 | ) 550 | 551 | # Mock the get_job_goodput_details to return test data, including an 552 | # excluded type 553 | goodput_monitor._goodput_calculator.get_job_goodput_details = MagicMock( 554 | return_value={ 555 | MetricType.GOODPUT_TIME.value: { 556 | GoodputType.TOTAL: 10.0, 557 | }, 558 | MetricType.BADPUT_TIME.value: { 
559 | BadputType.TPU_INITIALIZATION: 2.0, 560 | BadputType.DATA_LOADING_SYNC: 1.0, 561 | BadputType.DATA_LOADING_ASYNC: ( 562 | 3.0 563 | ), # DATA_LOADING_ASYNC is in ACTIVITY_EXCLUSION_LIST 564 | }, 565 | MetricType.DISRUPTION_COUNT.value: 0, 566 | MetricType.MAX_PRODUCTIVE_STEP.value: 2, 567 | MetricType.TOTAL_ELAPSED_TIME.value: 20.0, 568 | MetricType.STEP_TIME_DEVIATION.value: { 569 | 0: 1.0, 570 | 1: 1.0, 571 | 2: 1.0, 572 | }, 573 | MetricType.IDEAL_STEP_TIME.value: 1.0, 574 | } 575 | ) 576 | 577 | goodput_monitor._upload_goodput_metrics_to_gcm( 578 | goodput_monitor._goodput_calculator.get_job_goodput_details() 579 | ) 580 | 581 | # Verify that create_time_series was called with the correct data, 582 | # excluding DATA_LOADING_ASYNC 583 | expected_calls = [ 584 | mock.call.create_time_series( 585 | name='projects/test-project', 586 | time_series=[ 587 | self._create_timeseries( 588 | 'compute.googleapis.com/workload/goodput_time', 589 | { 590 | 'goodput_source': 'TOTAL', 591 | 'accelerator_type': 'test-acc-type', 592 | }, 593 | 10.0, 594 | ) 595 | ], 596 | ), 597 | mock.call.create_time_series( 598 | name='projects/test-project', 599 | time_series=[ 600 | self._create_timeseries( 601 | 'compute.googleapis.com/workload/badput_time', 602 | { 603 | 'badput_source': 'TPU_INITIALIZATION', 604 | 'accelerator_type': 'test-acc-type', 605 | }, 606 | 2.0, 607 | ) 608 | ], 609 | ), 610 | mock.call.create_time_series( 611 | name='projects/test-project', 612 | time_series=[ 613 | self._create_timeseries( 614 | 'compute.googleapis.com/workload/badput_time', 615 | { 616 | 'badput_source': 'DATA_LOADING_SYNC', 617 | 'accelerator_type': 'test-acc-type', 618 | }, 619 | 1.0, 620 | ) 621 | ], 622 | ), 623 | mock.call.create_time_series( 624 | name='projects/test-project', 625 | time_series=[ 626 | self._create_timeseries( 627 | 'compute.googleapis.com/workload/disruptions', 628 | { 629 | 'accelerator_type': 'test-acc-type', 630 | 'window_type': 'CUMULATIVE', 631 | }, 632 | 0, 
633 | ) 634 | ], 635 | ), 636 | mock.call.create_time_series( 637 | name='projects/test-project', 638 | time_series=[ 639 | self._create_timeseries( 640 | 'compute.googleapis.com/workload/max_productive_steps', 641 | { 642 | 'accelerator_type': 'test-acc-type', 643 | }, 644 | 2, 645 | ) 646 | ], 647 | ), 648 | mock.call.create_time_series( 649 | name='projects/test-project', 650 | time_series=[ 651 | self._create_timeseries( 652 | 'compute.googleapis.com/workload/total_elapsed_time', 653 | { 654 | 'accelerator_type': 'test-acc-type', 655 | 'window_type': 'CUMULATIVE', 656 | }, 657 | 20.0, 658 | ) 659 | ], 660 | ), 661 | mock.call.create_time_series( 662 | name='projects/test-project', 663 | time_series=[ 664 | self._create_timeseries( 665 | 'compute.googleapis.com/workload/step_time_deviation', 666 | { 667 | 'accelerator_type': 'test-acc-type', 668 | }, 669 | 1.0, 670 | ) 671 | ], 672 | ), 673 | mock.call.create_time_series( 674 | name='projects/test-project', 675 | time_series=[ 676 | self._create_timeseries( 677 | 'compute.googleapis.com/workload/performance', 678 | { 679 | 'accelerator_type': 'test-acc-type', 680 | }, 681 | 1.0, 682 | ) 683 | ], 684 | ), 685 | ] 686 | 687 | actual_calls = mock_client.create_time_series.call_args_list 688 | 689 | # Verify each call individually 690 | for expected_call in expected_calls: 691 | self.assertTrue( 692 | any( 693 | self._compare_calls_ignore_time_series(expected_call, actual) 694 | for actual in actual_calls 695 | ), 696 | f'Expected call not found: {expected_call}', 697 | ) 698 | # Verify unexpected calls are not made 699 | for actual_call in actual_calls: 700 | for ts in actual_call.kwargs.get('time_series', []): 701 | if ( 702 | 'badput_source' in ts.metric.labels 703 | and ts.metric.labels['badput_source'] == 'DATA_LOADING_ASYNC' 704 | ): 705 | self.fail(f'Unexpected call found: {ts}') 706 | 707 | @patch('google.cloud.monitoring_v3.MetricServiceClient') 708 | @patch('tensorboardX.writer.SummaryWriter') 709 | 
@patch('google.cloud.logging.Client') 710 | def test_send_interval_goodput_metrics_to_gcp( 711 | self, 712 | mock_logging_client, 713 | mock_summary_writer, 714 | mock_metric_service_client, 715 | ): 716 | mock_client = MagicMock() 717 | mock_metric_service_client.return_value = mock_client 718 | mock_logging_client.return_value = MagicMock() 719 | mock_summary_writer.return_value = MagicMock() 720 | 721 | gcp_options = GCPOptions( 722 | enable_gcp_goodput_metrics=True, 723 | project_id='test-project', 724 | location='test-location', 725 | acc_type='test-acc-type', 726 | replica_id='test-replica-id', 727 | ) 728 | 729 | goodput_monitor = GoodputMonitor( 730 | self.job_name, 731 | self.logger_name, 732 | self.tensorboard_dir, 733 | upload_interval=_TEST_UPLOAD_INTERVAL, 734 | monitoring_enabled=True, 735 | gcp_options=gcp_options, 736 | ) 737 | 738 | # Mock the get_job_goodput_details to return test data 739 | goodput_monitor._goodput_calculator.get_interval_metric_details = MagicMock( 740 | return_value={ 741 | IntervalMetricType.INTERVAL_GOODPUT.value: { 742 | GoodputType.TOTAL: 90.0, 743 | }, 744 | IntervalMetricType.INTERVAL_BADPUT.value: { 745 | BadputType.TPU_INITIALIZATION: 2.0, 746 | BadputType.DATA_LOADING_SYNC: 8.0, 747 | }, 748 | IntervalMetricType.INTERVAL_SIZE.value: 100, 749 | } 750 | ) 751 | 752 | goodput_monitor._upload_interval_goodput_metrics_to_gcm( 753 | goodput_monitor._goodput_calculator.get_interval_metric_details() 754 | ) 755 | 756 | expected_calls = [ 757 | mock.call.create_time_series( 758 | name='projects/test-project', 759 | time_series=[ 760 | self._create_timeseries( 761 | 'compute.googleapis.com/workload/interval_goodput', 762 | { 763 | 'goodput_source': 'TOTAL', 764 | 'accelerator_type': 'test-acc-type', 765 | 'rolling_window_size': '100', 766 | }, 767 | 90.0, 768 | ) 769 | ], 770 | ), 771 | mock.call.create_time_series( 772 | name='projects/test-project', 773 | time_series=[ 774 | self._create_timeseries( 775 | 
'compute.googleapis.com/workload/interval_badput', 776 | { 777 | 'badput_source': 'TPU_INITIALIZATION', 778 | 'accelerator_type': 'test-acc-type', 779 | 'rolling_window_size': '100', 780 | }, 781 | 2.0, 782 | ) 783 | ], 784 | ), 785 | mock.call.create_time_series( 786 | name='projects/test-project', 787 | time_series=[ 788 | self._create_timeseries( 789 | 'compute.googleapis.com/workload/interval_badput', 790 | { 791 | 'badput_source': 'DATA_LOADING_SYNC', 792 | 'accelerator_type': 'test-acc-type', 793 | 'rolling_window_size': '100', 794 | }, 795 | 8.0, 796 | ) 797 | ], 798 | ), 799 | ] 800 | 801 | actual_calls = mock_client.create_time_series.call_args_list 802 | 803 | # Verify each call individually 804 | for expected_call in expected_calls: 805 | self.assertTrue( 806 | any( 807 | self._compare_calls_ignore_time_series(expected_call, actual) 808 | for actual in actual_calls 809 | ), 810 | f'Expected call not found: {expected_call}', 811 | ) 812 | 813 | @patch('google.cloud.monitoring_v3.MetricServiceClient') 814 | @patch('tensorboardX.writer.SummaryWriter') 815 | @patch('google.cloud.logging.Client') 816 | def test_send_goodput_metrics_custom_sync_events( 817 | self, mock_logging_client, mock_summary_writer, mock_metric_service_client 818 | ): 819 | mock_client = MagicMock() 820 | mock_metric_service_client.return_value = mock_client 821 | mock_logging_client.return_value = MagicMock() 822 | mock_summary_writer.return_value = MagicMock() 823 | 824 | gcp_options = GCPOptions( 825 | enable_gcp_goodput_metrics=True, 826 | project_id='test-project', 827 | location='test-location', 828 | acc_type='test-acc-type', 829 | replica_id='test-replica-id', 830 | ) 831 | 832 | goodput_monitor = GoodputMonitor( 833 | self.job_name, 834 | self.logger_name, 835 | self.tensorboard_dir, 836 | upload_interval=_TEST_UPLOAD_INTERVAL, 837 | monitoring_enabled=True, 838 | gcp_options=gcp_options, 839 | ) 840 | 841 | # Mock the get_job_goodput_details to return test data, including an 
842 | # excluded type 843 | goodput_monitor._goodput_calculator.get_job_goodput_details = MagicMock( 844 | return_value={ 845 | MetricType.GOODPUT_TIME.value: { 846 | GoodputType.TOTAL: 10.0, 847 | }, 848 | MetricType.BADPUT_TIME.value: { 849 | BadputType.TPU_INITIALIZATION: 2.0, 850 | BadputType.DATA_LOADING_SYNC: 1.0, 851 | BadputType.CUSTOM_BADPUT_EVENTS: { 852 | 'EVAL_STEP': 3.0, 853 | 'SDC_COMPILATION': 4.0, 854 | }, 855 | }, 856 | MetricType.DISRUPTION_COUNT.value: 0, 857 | MetricType.MAX_PRODUCTIVE_STEP.value: 2, 858 | MetricType.TOTAL_ELAPSED_TIME.value: 20.0, 859 | MetricType.STEP_TIME_DEVIATION.value: { 860 | 0: 1.0, 861 | 1: 1.0, 862 | 2: 1.0, 863 | }, 864 | MetricType.IDEAL_STEP_TIME.value: 1.0, 865 | } 866 | ) 867 | 868 | goodput_monitor._upload_goodput_metrics_to_gcm( 869 | goodput_monitor._goodput_calculator.get_job_goodput_details() 870 | ) 871 | 872 | expected_calls = [ 873 | mock.call.create_time_series( 874 | name='projects/test-project', 875 | time_series=[ 876 | self._create_timeseries( 877 | 'compute.googleapis.com/workload/goodput_time', 878 | { 879 | 'goodput_source': 'TOTAL', 880 | 'accelerator_type': 'test-acc-type', 881 | }, 882 | 10.0, 883 | ) 884 | ], 885 | ), 886 | mock.call.create_time_series( 887 | name='projects/test-project', 888 | time_series=[ 889 | self._create_timeseries( 890 | 'compute.googleapis.com/workload/badput_time', 891 | { 892 | 'badput_source': 'TPU_INITIALIZATION', 893 | 'accelerator_type': 'test-acc-type', 894 | }, 895 | 2.0, 896 | ) 897 | ], 898 | ), 899 | mock.call.create_time_series( 900 | name='projects/test-project', 901 | time_series=[ 902 | self._create_timeseries( 903 | 'compute.googleapis.com/workload/badput_time', 904 | { 905 | 'badput_source': 'DATA_LOADING_SYNC', 906 | 'accelerator_type': 'test-acc-type', 907 | }, 908 | 1.0, 909 | ) 910 | ], 911 | ), 912 | mock.call.create_time_series( 913 | name='projects/test-project', 914 | time_series=[ 915 | self._create_timeseries( 916 | 
'compute.googleapis.com/workload/disruptions', 917 | { 918 | 'accelerator_type': 'test-acc-type', 919 | 'window_type': 'CUMULATIVE', 920 | }, 921 | 0, 922 | ) 923 | ], 924 | ), 925 | mock.call.create_time_series( 926 | name='projects/test-project', 927 | time_series=[ 928 | self._create_timeseries( 929 | 'compute.googleapis.com/workload/max_productive_steps', 930 | { 931 | 'accelerator_type': 'test-acc-type', 932 | }, 933 | 2, 934 | ) 935 | ], 936 | ), 937 | mock.call.create_time_series( 938 | name='projects/test-project', 939 | time_series=[ 940 | self._create_timeseries( 941 | 'compute.googleapis.com/workload/total_elapsed_time', 942 | { 943 | 'accelerator_type': 'test-acc-type', 944 | 'window_type': 'CUMULATIVE', 945 | }, 946 | 20.0, 947 | ) 948 | ], 949 | ), 950 | mock.call.create_time_series( 951 | name='projects/test-project', 952 | time_series=[ 953 | self._create_timeseries( 954 | 'compute.googleapis.com/workload/step_time_deviation', 955 | { 956 | 'accelerator_type': 'test-acc-type', 957 | }, 958 | 1.0, 959 | ) 960 | ], 961 | ), 962 | mock.call.create_time_series( 963 | name='projects/test-project', 964 | time_series=[ 965 | self._create_timeseries( 966 | 'compute.googleapis.com/workload/performance', 967 | { 968 | 'accelerator_type': 'test-acc-type', 969 | }, 970 | 1.0, 971 | ) 972 | ], 973 | ), 974 | ] 975 | 976 | actual_calls = mock_client.create_time_series.call_args_list 977 | 978 | # Verify each call individually 979 | for expected_call in expected_calls: 980 | self.assertTrue( 981 | any( 982 | self._compare_calls_ignore_time_series(expected_call, actual_call) 983 | for actual_call in actual_calls 984 | ), 985 | f'Expected call not found: {expected_call}', 986 | ) 987 | 988 | @patch( 989 | 'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._final_interval_goodput_query_and_upload' 990 | ) 991 | @patch( 992 | 'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._final_step_deviation_query_and_upload' 993 | ) 994 | 
@patch( 995 | 'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._final_goodput_query_and_upload' 996 | ) 997 | async def test_goodput_monitor_final_query_and_upload( 998 | self, 999 | mock_final_goodput_query_and_upload, 1000 | mock_final_step_deviation_query_and_upload, 1001 | mock_final_interval_goodput_query_and_upload, 1002 | ): 1003 | mock_final_goodput_query_and_upload.return_value = MagicMock() 1004 | mock_final_step_deviation_query_and_upload.return_value = MagicMock() 1005 | mock_final_interval_goodput_query_and_upload.return_value = MagicMock() 1006 | goodput_monitor = monitoring.GoodputMonitor( 1007 | self.job_name, 1008 | self.logger_name, 1009 | self.tensorboard_dir, 1010 | upload_interval=_TEST_UPLOAD_INTERVAL, 1011 | monitoring_enabled=True, 1012 | ) 1013 | goodput_monitor.__del__() 1014 | mock_final_goodput_query_and_upload.assert_called_once() 1015 | mock_final_step_deviation_query_and_upload.assert_called_once() 1016 | mock_final_interval_goodput_query_and_upload.assert_called_once() 1017 | 1018 | 1019 | if __name__ == '__main__': 1020 | absltest.main() 1021 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 

# Package metadata for ml_goodput_measurement (PEP 621), built with flit.
[project]
name = "ml_goodput_measurement"
version = "0.0.11"
authors = [
  { name="Cloud TPU Team", email="cloud-tpu-eng@google.com" },
]
description = "Package to monitor Goodput, Badput and other metrics of ML workloads."
readme = "README.md"
requires-python = ">=3.8"
license = {text = "Apache-2.0"}
classifiers = [
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
]
keywords = []

# pip dependencies installed with `pip install -e .`
# NOTE(review): the PyPI project is published as "tensorboardX"; pip resolves
# names case-insensitively, so the lowercase spelling below still works.
dependencies = [
    "google-api-core>=2.24.1",
    "google-cloud-logging>=3.5.0",
    "google-cloud-monitoring>=2.20.0",
    "numpy",
    "requests",
    "scipy",
    "tensorboardx",
    "urllib3",
]

[project.urls]
"Homepage" = "https://github.com/AI-Hypercomputer/ml-goodput-measurement"
"Bug Tracker" = "https://github.com/AI-Hypercomputer/ml-goodput-measurement/issues"

[build-system]
# The build system specifies which backend is used to build/install the project.
requires = ["flit_core >=3.8,<4"]
build-backend = "flit_core.buildapi"

[tool.flit.sdist]
# Flit-specific options (files to exclude from the PyPI source distribution).
exclude = [
  # Do not release test files on PyPI.
  "tests/*_test.py",
]