├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── mem_remat.png
└── remat.patch

/.gitignore:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 | *.d
3 |
4 | # Object files
5 | *.o
6 | *.ko
7 | *.obj
8 | *.elf
9 |
10 | # Linker output
11 | *.ilk
12 | *.map
13 | *.exp
14 |
15 | # Precompiled Headers
16 | *.gch
17 | *.pch
18 |
19 | # Libraries
20 | *.lib
21 | *.a
22 | *.la
23 | *.lo
24 |
25 | # Shared objects (inc. Windows DLLs)
26 | *.dll
27 | *.so
28 | *.so.*
29 | *.dylib
30 |
31 | # Executables
32 | *.exe
33 | *.out
34 | *.app
35 | *.i*86
36 | *.x86_64
37 | *.hex
38 |
39 | # Debug files
40 | *.dSYM/
41 | *.su
42 | *.idb
43 | *.pdb
44 |
45 | # Kernel Module Compile Results
46 | *.mod*
47 | *.cmd
48 | .tmp_versions/
49 | modules.order
50 | Module.symvers
51 | Mkfile.old
52 | dkms.conf
53 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow XLA Rematerialization
2 |
3 | This project provides an interface for using rematerialization in the TensorFlow XLA compiler.
4 |
5 | # Installing
6 |
7 | This patch adds rematerialization support to the TensorFlow 2.0 XLA branch. To install it, execute the following commands:
8 |
9 | ```
10 | git clone https://github.com/tensorflow/tensorflow
11 | cd tensorflow
12 | git checkout 64c3d382cadf7bbe8e7e99884bede8284ff67f56
13 | git clone https://github.com/microsoft/tensorflow-rematerialization
14 | git apply -v tensorflow-rematerialization/remat.patch
15 | ```
16 |
17 | After cloning TensorFlow and applying the remat patch, you can compile and install it by following the TensorFlow documentation: https://www.tensorflow.org/install/source
18 |
19 | # Using
20 |
21 | You can control rematerialization in XLA by passing flags through the XLA_FLAGS environment variable. The available flags are:
22 |
23 | ```
24 | • --xla_use_hlo_rematerialization
25 | ○ This flag enables the rematerialization pass, which runs after all optimizations and just before code emission. This flag is required for any of the following flags to take effect.
26 |
27 | • --xla_rematerialization_mem_limit={NUMBER}
28 | ○ This flag sets the memory budget for your application.
29 | The remat heuristic will try to fit the model in that amount of memory.
30 | If the memory budget is larger than the model size, no rematerialization is applied. The parameter {NUMBER} is in bytes.
31 |
32 | • --xla_rematerialization_scheduler={SCHEDULER_NAME}
33 | ○ This flag selects which scheduler to use just before rematerialization. The scheduling can impact the performance of the heuristics.
34 | Four options are available:
35 | (1) if no scheduler is set, the default TF behavior is to apply the next three schedulers and select the best one in terms of performance (not considering remat here);
36 | (2) "postorder";
37 | (3) "DFS"; and
38 | (4) "list".
39 |
40 | • --xla_rematerialization_algorithm={HEURISTIC_NAME}
41 | ○ This flag sets which heuristic to use for rematerialization.
42 | Four options are available:
43 | (1) "standard", which uses the original implementation of rematerialization found inside the TF source tree;
44 | (2) "compress", which tries to reorder the dimensions of tensors in order to reduce their representation size (this implementation is also inside TF, but it is not a rematerialization);
45 | (3) "standardcompress", which applies both the standard remat and the compress technique; and, finally, (4) "path", our approach to remat, which recursively tries to rematerialize paths and then derematerialize articulation operations while still within the memory budget.
46 |
47 | • --xla_rematerialization_small_node_limit={NUMBER}
48 | ○ This flag sets the minimum size that an HLO node must have to be considered for rematerialization. The expected {NUMBER} is in MiB, and 0 disables this feature. A limit of 1 MiB was shown to significantly improve the performance of both the standard and path heuristics.
49 |
50 | • --xla_rematerialization_disable_cuda
51 | ○ This flag disables part of the CUDA fusions in the HLO graph, making it easier to apply rematerialization since fewer side-effect nodes will exist.
52 |
53 | • --xla_rematerialization_dump_dot
54 | ○ Dumps a dot graph of the HLO before and after rematerialization.
55 |
56 | • --xla_rematerialization_dump_memlog
57 | ○ Dumps a log of memory use and rematerialization decisions for each HLO instruction.
58 | ```
59 |
60 | Note: TF_XLA_FLAGS="--tf_xla_auto_jit=2" needs to be set to activate the XLA compiler.
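Because all of the above are plain environment variables, no code changes are required: any TensorFlow program whose graphs XLA compiles will pick them up. The sketch below is a minimal, hypothetical workload (not part of this repository; the script name, model, and sizes are illustrative only) that you could run under these flags:

```python
# minimal_remat_demo.py -- hypothetical stand-in for a real training script.
import tensorflow as tf

# A small dense model. With TF_XLA_FLAGS="--tf_xla_auto_jit=2" set, the
# training step is clustered and compiled by XLA, so the XLA_FLAGS described
# above (memory limit, heuristic, scheduler) govern the rematerialization pass.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(1024, activation="relu", input_shape=(256,)),
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

# Synthetic data keeps the example self-contained.
x = tf.random.normal([512, 256])
y = tf.random.uniform([512], maxval=10, dtype=tf.int32)
model.fit(x, y, batch_size=64, epochs=1)
```

Note that the memory limit is given in bytes, so a 1 GiB budget is written as --xla_rematerialization_mem_limit=1073741824, as in the example below.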
61 |
62 | Example:
63 |
64 | ```
65 | TF_XLA_FLAGS="--tf_xla_auto_jit=2" XLA_FLAGS="--xla_dump_to=dump --xla_dump_hlo_as_text --xla_use_hlo_rematerialization --xla_rematerialization_mem_limit=1073741824 --xla_rematerialization_algorithm=path --xla_rematerialization_small_node_limit=1" python resnet_cifar_main.py
66 | ```
67 |
68 | # Results
69 |
70 | Using rematerialization in TensorFlow XLA can help reduce the amount of memory necessary to run a model, allowing it to fit in your accelerator.
71 |
72 | ![memory before and after using XLA rematerialization](https://raw.githubusercontent.com/microsoft/tensorflow-rematerialization/master/mem_remat.png)
73 |
74 |
75 | # Contributing
76 |
77 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
78 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
79 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
80 |
81 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
82 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
83 | provided by the bot. You will only need to do this once across all repos using our CLA.
84 |
85 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
86 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
87 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
88 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /mem_remat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-rematerialization/24fb20a361d0de4d86abb7f91f4626c174250758/mem_remat.png -------------------------------------------------------------------------------- /remat.patch: -------------------------------------------------------------------------------- 1 | diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc 2 | index 93ae3d2..39a6266 100644 3 | --- a/tensorflow/compiler/xla/debug_options_flags.cc 4 | +++ b/tensorflow/compiler/xla/debug_options_flags.cc 5 | @@ -59,6 +59,16 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { 6 | 7 | opts.set_xla_allow_excess_precision(true); 8 | opts.set_xla_force_host_platform_device_count(1); 9 | + 10 | + opts.set_xla_use_hlo_rematerialization(false); 11 | + opts.set_xla_rematerialization_mem_limit("0"); 12 | + opts.set_xla_rematerialization_scheduler("default"); 13 | + opts.set_xla_rematerialization_algorithm("standard"); 14 | + opts.set_xla_rematerialization_small_node_limit(1); 15 | + opts.set_xla_rematerialization_disable_cuda(false); 16 | + opts.set_xla_rematerialization_dump_dot(false); 17 | + opts.set_xla_rematerialization_dump_memlog(false); 18 | + 19 | return opts; 20 | } 21 | 22 | @@ -440,6 +450,54 @@ static void AllocateFlags() { 23 | "--xla_fuel=PASS1=NUM1,PASS2=NUM2,..."), 24 | 25 | tensorflow::Flag( 26 | + "xla_use_hlo_rematerialization", 27 | + bool_setter_for(&DebugOptions::set_xla_use_hlo_rematerialization), 28 | + flag_values->xla_use_hlo_rematerialization(), 29 | + "Enables HLO rematerialization heuristic which tries either to reduce" 30 | + " memory consunpution as much as possible or until below a limit " 31 | + "setted by --xla_rematerialization_mem_limit"), 32 | + tensorflow::Flag( 33 | + "xla_rematerialization_mem_limit", 34 | + string_setter_for(&DebugOptions::set_xla_rematerialization_mem_limit), 35 | + flag_values->xla_rematerialization_mem_limit(), 36 | + "Sets a memory limit goal (in bytes) to the HLO rematerialization " 37 | + "heuristic."), 38 | + tensorflow::Flag( 39 | + "xla_rematerialization_scheduler", 40 | + 
41 | + flag_values->xla_rematerialization_scheduler(),
42 | + "Sets the scheduler to be used just before rematerialization."
43 | + " Options are: default, postorder, DFS, and list."),
44 | + tensorflow::Flag(
45 | + "xla_rematerialization_algorithm",
46 | + string_setter_for(&DebugOptions::set_xla_rematerialization_algorithm),
47 | + flag_values->xla_rematerialization_algorithm(),
48 | + "Sets the rematerialization or compression technique to be used."
49 | + " Options are: standard, compress, standardcompress, and path."),
50 | + tensorflow::Flag(
51 | + "xla_rematerialization_small_node_limit",
52 | + int32_setter_for(&DebugOptions::set_xla_rematerialization_small_node_limit),
53 | + flag_values->xla_rematerialization_small_node_limit(),
54 | + "Sets the minimum size (in MiB) that a candidate for rematerialization"
55 | + " needs to have."),
56 | + tensorflow::Flag(
57 | + "xla_rematerialization_disable_cuda",
58 | + bool_setter_for(&DebugOptions::set_xla_rematerialization_disable_cuda),
59 | + flag_values->xla_rematerialization_disable_cuda(),
60 | + "Disables part of the CUDA fusion optimizations (this can improve remat)."),
61 | + tensorflow::Flag(
62 | + "xla_rematerialization_dump_dot",
63 | + bool_setter_for(&DebugOptions::set_xla_rematerialization_dump_dot),
64 | + flag_values->xla_rematerialization_dump_dot(),
65 | + "Dumps a dot representation of the HLO graph."),
66 | + tensorflow::Flag(
67 | + "xla_rematerialization_dump_memlog",
68 | + bool_setter_for(&DebugOptions::set_xla_rematerialization_dump_memlog),
69 | + flag_values->xla_rematerialization_dump_memlog(),
70 | + "Dumps a log of memory usage after the rematerialization."),
71 | +
72 | +
73 | + tensorflow::Flag(
74 | "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to),
75 | flag_values->xla_dump_to(),
76 | "Directory into which debugging data is written. If not specified "
77 | diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
78 | index 581d358..bab48f1 100644
79 | --- a/tensorflow/compiler/xla/service/BUILD
80 | +++ b/tensorflow/compiler/xla/service/BUILD
81 | @@ -2957,6 +2957,7 @@ cc_library(
82 | ":flatten_call_graph",
83 | ":hlo",
84 | ":hlo_dce",
85 | + ":hlo_cost_analysis",
86 | ":hlo_memory_scheduler",
87 | ":hlo_ordering",
88 | ":logical_buffer",
89 | diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
90 | index 7d65624..9e8db31 100644
91 | --- a/tensorflow/compiler/xla/service/cpu/BUILD
92 | +++ b/tensorflow/compiler/xla/service/cpu/BUILD
93 | @@ -130,6 +130,7 @@ cc_library(
94 | "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
95 | "//tensorflow/compiler/xla/service:hlo_proto",
96 | "//tensorflow/compiler/xla/service:hlo_proto_util",
97 | + "//tensorflow/compiler/xla/service:hlo_rematerialization",
98 | "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
99 | "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
100 | "//tensorflow/compiler/xla/service:hlo_verifier",
101 | diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
102 | index acafa2c..85d388f 100644
103 | --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
104 | +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
105 | @@ -53,6 +53,7 @@ limitations under the License.
106 | #include "tensorflow/compiler/xla/service/call_inliner.h"
107 | #include "tensorflow/compiler/xla/service/cholesky_expander.h"
108 | #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
109 | +#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
110 | #include "tensorflow/compiler/xla/service/conditional_to_select.h"
111 | #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
112 | #include "tensorflow/compiler/xla/service/copy_insertion.h"
113 | @@ -601,6 +602,68 @@ struct OrcJITPostCompilationHook {
114 |
115 | } // namespace
116 |
117 | +// Return the byte size of the top-level buffer of the given shape.
118 | +static int64 ByteSizeOf(const Shape& shape) {
119 | + return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
120 | +}
121 | +
122 | +static StatusOr<Shape> ChooseCompactLayoutForShape(const Shape& shape) {
123 | + Shape result = shape;
124 | + Layout layout = result.layout();
125 | + int64 most_minor_index = layout.minor_to_major()[0];
126 | + int64 second_minor_index = layout.minor_to_major()[1];
127 | + int64 most_minor = result.dimensions(most_minor_index);
128 | + int64 second_minor = result.dimensions(second_minor_index);
129 | + if (most_minor < second_minor) {
130 | + result.set_dimensions(most_minor_index, second_minor);
131 | + result.set_dimensions(second_minor_index, most_minor);
132 | + }
133 | + return result;
134 | +}
135 | +
136 | +StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
137 | + HloModule* module) {
138 | +
139 | + auto sch = DefaultMemoryScheduler;
140 | + string scheduler_option =
141 | + module->config().debug_options().xla_rematerialization_scheduler();
142 | +
143 | + if (scheduler_option == "postorder") {
144 | + sch = PostOrderMemoryScheduler;
145 | + } else if (scheduler_option == "DFS") {
146 | + sch = DFSMemoryScheduler;
147 | + } else if (scheduler_option == "list") {
148 | + sch = ListMemoryScheduler;
149 | + }
150 | +
151 | + HloMemoryScheduler scheduler(
152 | + [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
153 | + ComputationSchedulerToModuleScheduler(
154 | + sch
155 | + ));
156 | +
157 | + TF_RETURN_IF_ERROR(scheduler.Run(module).status());
158 | +
159 | + RematerializationAlg alg = kStandardAlg;
160 | + string algorithm_option =
161 | + module->config().debug_options().xla_rematerialization_algorithm();
162 | +
163 | + if (algorithm_option == "compress") {
164 | + alg = kCompressAlg;
165 | + } else if (algorithm_option == "standardcompress") {
166 | + alg = kStandardAndCompressAlg;
167 | + } else if (algorithm_option == "path") {
168 | + alg = kPathAlg;
169 | + }
170 | +
171 | + DumpHloModuleIfEnabled(*module, "before_remat");
172 | + HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
173 | + /*sizes=*/nullptr, ChooseCompactLayoutForShape);
174 | + remat.setAlgorithm(alg);
175 | + return remat.Run(module);
176 | +}
177 | +
178 | +
179 | StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
180 | std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
181 | se::DeviceMemoryAllocator* /*device_allocator*/) {
182 | @@ -613,6 +676,22 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
183 | std::call_once(llvm_command_line_options_initialized,
184 | &llvm_ir::InitializeLLVMCommandLineOptions, module->config());
185 |
186 | + // Rematerialization needs to be applied after all optimizations
187 | + if (module->config().debug_options().xla_use_hlo_rematerialization()) {
188 | + string mem_limit_s =
189 | + module->config().debug_options().xla_rematerialization_mem_limit();
190 | +
191 | + LOG(WARNING) << "Starting rematerialization of " <<
192 | + module->name() << " with " << mem_limit_s << " bytes as mem limit";
193 | +
194 | + int64_t mem_limit_u = std::stoull(mem_limit_s);
195 | +
196 | + StatusOr<bool> remat_result =
197 | + RunHloRematerialization(mem_limit_u, module.get());
198 | +
199 | + TF_RETURN_IF_ERROR(remat_result.status());
200 | + }
201 | +
202 | ModuleHook pre_optimization_ir_hook;
203 | ModuleHook post_optimization_ir_hook;
204 | std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
205 | diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
206 | index 866df46..59f0fe0 100644
207 | --- a/tensorflow/compiler/xla/service/gpu/BUILD
208 | +++ b/tensorflow/compiler/xla/service/gpu/BUILD
209 | @@ -1008,6 +1008,7 @@ cc_library(
210 | "//tensorflow/compiler/xla/service:hlo",
211 | "//tensorflow/compiler/xla/service:hlo_constant_folding",
212 | "//tensorflow/compiler/xla/service:hlo_cse",
213 | + "//tensorflow/compiler/xla/service:hlo_rematerialization",
214 | "//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
215 | "//tensorflow/compiler/xla/service:hlo_dce",
216 | "//tensorflow/compiler/xla/service:hlo_element_type_converter",
217 | diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
218 | index 9dda327..f5ca615 100644
219 | --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
220 | +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
221 | @@ -78,6 +78,7 @@ limitations under the License.
222 | #include "tensorflow/compiler/xla/service/hlo_cse.h"
223 | #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
224 | #include "tensorflow/compiler/xla/service/hlo_dce.h"
225 | +#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
226 | #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
227 | #include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
228 | #include "tensorflow/compiler/xla/service/hlo_instruction.h"
229 | @@ -259,6 +260,67 @@ bool MaybeLoadPtxFromFile(const HloModule* module, std::string* ptx) {
230 |
231 | } // namespace
232 |
233 | +// Return the byte size of the top-level buffer of the given shape.
234 | +static int64 ByteSizeOf(const Shape& shape) {
235 | + return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
236 | +}
237 | +
238 | +static StatusOr<Shape> ChooseCompactLayoutForShape(const Shape& shape) {
239 | + Shape result = shape;
240 | + Layout layout = result.layout();
241 | + int64 most_minor_index = layout.minor_to_major()[0];
242 | + int64 second_minor_index = layout.minor_to_major()[1];
243 | + int64 most_minor = result.dimensions(most_minor_index);
244 | + int64 second_minor = result.dimensions(second_minor_index);
245 | + if (most_minor < second_minor) {
246 | + result.set_dimensions(most_minor_index, second_minor);
247 | + result.set_dimensions(second_minor_index, most_minor);
248 | + }
249 | + return result;
250 | +}
251 | +
252 | +StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
253 | + HloModule* module) {
254 | +
255 | + auto sch = DefaultMemoryScheduler;
256 | + string scheduler_option =
257 | + module->config().debug_options().xla_rematerialization_scheduler();
258 | +
259 | + if (scheduler_option == "postorder") {
260 | + sch = PostOrderMemoryScheduler;
261 | + } else if (scheduler_option == "DFS") {
262 | + sch = DFSMemoryScheduler;
263 | + } else if (scheduler_option == "list") {
264 | + sch = ListMemoryScheduler;
265 | + }
266 | +
267 | + HloMemoryScheduler scheduler(
268 | + [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
269 | + ComputationSchedulerToModuleScheduler(
270 | + sch
271 | + ));
272 | +
273 | + TF_RETURN_IF_ERROR(scheduler.Run(module).status());
274 | +
275 | + RematerializationAlg alg = kStandardAlg;
276 | + string algorithm_option =
277 | + module->config().debug_options().xla_rematerialization_algorithm();
278 | +
279 | + if (algorithm_option == "compress") {
280 | + alg = kCompressAlg;
281 | + } else if (algorithm_option == "standardcompress") {
282 | + alg = kStandardAndCompressAlg;
283 | + } else if (algorithm_option == "path") {
284 | + alg = kPathAlg;
285 | + }
286 | +
287 | + DumpHloModuleIfEnabled(*module, "before_remat");
288 | + HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
289 | + /*sizes=*/nullptr, ChooseCompactLayoutForShape);
290 | + remat.setAlgorithm(alg);
291 | + return remat.Run(module);
292 | +}
293 | +
294 | // Runs optimization passes on the given HLO module.
295 | Status impl::OptimizeHloModule(HloModule* hlo_module,
296 | se::StreamExecutor* stream_exec,
297 | @@ -379,6 +441,7 @@ Status impl::OptimizeHloModule(HloModule* hlo_module,
298 | // tuple/get-tuple-element pairs that TupleSimplifier fixes.
299 | pipeline.AddPass<TupleSimplifier>();
300 | }
301 | +
302 | // CudnnConvRewriter, CudnnConvPaddingLegalization and
303 | // CudnnConvPadForTensorCores may add instructions which can be simplified
304 | // by constant folding.
305 | @@ -399,7 +462,8 @@ Status impl::OptimizeHloModule(HloModule* hlo_module,
306 | LayoutAssignment::InstructionCanChangeLayout, stream_exec);
307 | TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
308 | }
309 | -
310 | +
311 | + if (!hlo_module->config().debug_options().xla_rematerialization_disable_cuda())
312 | {
313 | HloPassPipeline pipeline("post-layout_assignment");
314 | /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
315 | @@ -558,6 +622,22 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
316 |
317 | TF_RET_CHECK(stream_exec != nullptr);
318 |
319 | + // Rematerialization needs to be applied after all optimizations
320 | + if (module->config().debug_options().xla_use_hlo_rematerialization()) {
321 | + string mem_limit_s =
322 | + module->config().debug_options().xla_rematerialization_mem_limit();
323 | +
324 | + LOG(WARNING) << "Starting rematerialization of " <<
325 | + module->name() << " with " << mem_limit_s << " bytes as mem limit";
326 | +
327 | + int64_t mem_limit_u = std::stoull(mem_limit_s);
328 | +
329 | + StatusOr<bool> remat_result =
330 | + RunHloRematerialization(mem_limit_u, module.get());
331 | +
332 | + TF_RETURN_IF_ERROR(remat_result.status());
333 | + }
334 | +
335 | llvm::LLVMContext llvm_context;
336 | std::string buffer;
337 | llvm::raw_string_ostream error(buffer);
338 | @@ -605,7 +685,6 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
339 | &ir_emitter_context);
340 |
341 | TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
342 | -
343 | {
344 | XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - IR emission");
345 | TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
346 | diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
347 | index 603371d..e3f1386 100644
348 | --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
349 | +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
350 | @@ -19,6 +19,7 @@ limitations under the License.
351 | #include
352 | #include
353 | #include
354 | +#include
355 |
356 | #include "absl/container/flat_hash_map.h"
357 | #include "absl/container/flat_hash_set.h"
358 | @@ -66,14 +67,24 @@ bool IsRematerializable(const HloInstruction* instruction) {
359 | }
360 | }
361 |
362 | + // Don't rematerialize instructions smaller than the small-node limit. This
363 | + // improves rematerialization stability across different memory-limit budgets.
364 | + int small_node_limit = instruction->parent()->parent()
365 | + ->config().debug_options().xla_rematerialization_small_node_limit();
366 | + if (small_node_limit != 0 &&
367 | + ShapeUtil::ByteSizeOf(instruction->shape(), sizeof(void*))
368 | + <= small_node_limit*1024*1024) {
369 | + return false;
370 | + }
371 | +
372 | // Don't rematerialize instructions with side effects or instructions which
373 | // cannot be cloned safely.
374 | switch (instruction->opcode()) {
375 | case HloOpcode::kCall:
376 | + case HloOpcode::kCustomCall:
377 | case HloOpcode::kConstant:
378 | case HloOpcode::kConditional:
379 | case HloOpcode::kAllReduce:
380 | - case HloOpcode::kCustomCall:
381 | case HloOpcode::kParameter:
382 | case HloOpcode::kWhile:
383 | return false;
384 | @@ -100,6 +111,17 @@ bool CanBeRematerialized(
385 | using BufferId = int64;
386 | using BufferIdList = absl::InlinedVector<BufferId, 3>;
387 |
388 | +struct RematStrategy {
389 | + enum {
390 | + // Recompute the node at a later program point.
391 | + kRecompute,
392 | + // Change the layout into a compact form and uncompress it back at a later
393 | + // program point.
394 | + kCompress,
395 | + } kind;
396 | + Shape compact_shape;
397 | +};
398 | +
399 | // We wrap HloInstruction* with an Item that holds auxiliary
400 | // per-instruction state.
401 | struct Item {
402 | @@ -117,6 +139,10 @@ struct Item {
403 | // The buffers defined by this instruction.
404 | BufferIdList buffers_defined;
405 |
406 | + // Output buffers of this instruction. This is used to track outputs by GTE
407 | + // instructions (where the instruction doesn't define a buffer).
408 | + BufferIdList buffers_output;
409 | +
410 | // The buffers used by this instruction.
411 | BufferIdList buffers_used;
412 |
413 | @@ -251,6 +277,32 @@ class InstructionList {
414 | return InsertBefore(to_insert, min_position_item);
415 | }
416 |
417 | + void InsertAfterInstructions(Item* to_insert,
418 | + absl::Span<Item* const> after_instructions) {
419 | + VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
420 | + << " after {"
421 | + << absl::StrJoin(after_instructions, ", ",
422 | + [](string* out, Item* item) {
423 | + absl::StrAppend(out, item->instruction->name());
424 | + })
425 | + << "}";
426 | +
427 | + // Find the max position number of any instruction in
428 | + // 'after_instructions'.
429 | + CHECK(!after_instructions.empty());
430 | + Item* max_position_item = nullptr;
431 | + for (Item* item : after_instructions) {
432 | + if (max_position_item == nullptr ||
433 | + item->position > max_position_item->position) {
434 | + max_position_item = item;
435 | + }
436 | + }
437 | + // No rematerializable instruction should be inserted at the end of the
438 | + // computation.
439 | + CHECK(max_position_item->next != nullptr);
440 | + InsertBeforeInstructions(to_insert, {max_position_item->next});
441 | + }
442 | +
443 | void Blacklist(const HloInstruction* inst) {
444 | GetItem(inst)->blacklisted = true;
445 | }
446 | @@ -327,6 +379,7 @@ class MemoryUsageTracker {
447 | MemoryUsageTracker(
448 | const HloComputation* computation,
449 | const HloRematerialization::ShapeSizeFunction& size_function,
450 | + const HloRematerialization::CompactShapeFunction& compact_shape_function,
451 | const TuplePointsToAnalysis& points_to_analysis,
452 | const InstructionList& instruction_list);
453 |
454 | @@ -338,6 +391,22 @@ class MemoryUsageTracker {
455 | // EndInstruction memory for dead operand(s) is freed.
456 | Status BeginInstruction(Item* item);
457 |
458 | + int64 RematerializationCost(const HloInstruction* instruction,
459 | + int64 memory_reduced, int64 memory_limit_bytes) {
460 | + // If none of the users of 'instruction' have been placed in the sequence
461 | + // (as tracked by memory_tracker), then rematerialization of 'instruction'
462 | + // is a zero-cost move of 'instruction' in the sequence.
463 | + if (!absl::c_any_of(
464 | + instruction->users(),
465 | + [this](const HloInstruction* inst) { return IsPlaced(inst); })) {
466 | + return 0;
467 | + }
468 | +
469 | + CHECK_GT(memory_reduced, 0);
470 | + // Return the inverse of the benefit of rematerialization.
471 | + return memory_limit_bytes / memory_reduced;
472 | + }
473 | +
474 | // Finishes the placement of the current instruction. This frees any dead
475 | // operands or dead result of the instruction. This must be called after
476 | // each call to BeginInstruction.
477 | @@ -348,16 +417,28 @@ class MemoryUsageTracker {
478 | int64 MemoryReducedIfRematerialized(Item* item) const;
479 |
480 | // Returns the number of bytes that the current memory usage will be reduced
481 | + // if the given instruction is compact.
482 | + int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const;
483 | +
484 | + // Returns the number of bytes that the current memory usage will be reduced
485 | // by if the given sequence of instructions is rematerialized.
486 | int64 MemoryReducedIfRematerialized(const absl::Span& items) const;
487 |
488 | + Status AddCompressInstructions(Item* original_item, Item* compressed_item,
489 | + Item* uncompressed_item);
490 | +
491 | // Adjusts memory usage to account for the rematerialization of
492 | // original_item for all remaining unplaced uses. The rematerialization
493 | // is remat_item. This method should be called after the HLO graph has
494 | - // been transformed (rematerialization instruction created and connected to
495 | - // uses).
496 | + // been transformed (rematerialization instruction created and connected
497 | + // to uses).
498 | Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
499 |
500 | + std::pair<Item*, RematStrategy> PickRematerializationCandidate(
501 | + const RematerializationAlg,
502 | + const InstructionList& instruction_list, int64 memory_limit_bytes,
503 | + absl::flat_hash_map<const HloInstruction*, bool>* remat_able);
504 | +
505 | // Returns whether the given instruction has been placed (BeginInstruction
506 | // has been called with 'instruction' as the argument).
507 | bool IsPlaced(const HloInstruction* instruction) const {
508 | @@ -390,6 +471,9 @@ class MemoryUsageTracker {
509 | // The materialized size of the buffer in bytes.
510 | const int64 size;
511 |
512 | + // Shape of the buffer.
513 | + Shape shape;
514 | +
515 | // Whether this buffer is live-out of the computation.
516 | bool live_out;
517 |
518 | @@ -412,19 +496,21 @@ class MemoryUsageTracker {
519 | }
520 | };
521 |
522 | + // Get the compact shape of a given HLO instruction. An internal cache is used
523 | + // to avoid computing the shape multiple times.
524 | + StatusOr<Shape> GetCompactShape(const HloInstruction* hlo);
525 | +
526 | // Creates a Buffer representing the given logical buffer. The buffer is added
527 | // to buffers_ and a reference is returned.
528 | Buffer& CreateBufferFromLogicalBuffer( 529 | const LogicalBuffer* logical_buffer, 530 | - const TuplePointsToAnalysis& points_to_analysis, 531 | - const HloRematerialization::ShapeSizeFunction& size_function, 532 | - bool live_out) { 533 | + const TuplePointsToAnalysis& points_to_analysis, bool live_out) { 534 | bool has_indirect_uses = false; 535 | ItemList users = GetUsers(instruction_list_, logical_buffer, 536 | points_to_analysis, &has_indirect_uses); 537 | return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()), 538 | - size_function(logical_buffer->shape()), std::move(users), 539 | - live_out, has_indirect_uses); 540 | + logical_buffer->shape(), std::move(users), live_out, 541 | + has_indirect_uses); 542 | } 543 | 544 | // Create a new buffer representing a rematerialization of given buffer for 545 | @@ -438,7 +524,7 @@ class MemoryUsageTracker { 546 | for (Item* use : rematerialized_uses) { 547 | CHECK(!use->placed) << use->instruction->name(); 548 | } 549 | - return NewBuffer(remat_item, original_buffer.size, 550 | + return NewBuffer(remat_item, original_buffer.shape, 551 | std::move(rematerialized_uses), /*live_out=*/false, 552 | /*has_indirect_uses=*/false); 553 | } 554 | @@ -449,7 +535,8 @@ class MemoryUsageTracker { 555 | // different computation. 556 | int64 AllocatedSize(BufferId buffer_id) const { 557 | const Buffer& buffer = buffers_.at(buffer_id); 558 | - HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode(); 559 | + HloInstruction* inst = buffer.defining_instruction->instruction; 560 | + HloOpcode def_opcode = inst->opcode(); 561 | if (buffer.live_out || def_opcode == HloOpcode::kParameter) { 562 | return 0; 563 | } else { 564 | @@ -473,7 +560,7 @@ class MemoryUsageTracker { 565 | return absl::c_linear_search(in_progress_uses, buffer_id); 566 | } 567 | 568 | - // Returns whether the given instruction is live at the current program 569 | + // Returns whether the given buffer is live at the current program 570 | // point. 571 | bool IsCurrentlyLive(BufferId buffer_id) const { 572 | const Buffer& buffer = buffers_[buffer_id]; 573 | @@ -481,13 +568,30 @@ class MemoryUsageTracker { 574 | buffer.unfinished_user_count > 0); 575 | } 576 | 577 | + // Returns whether the given instruction is live at the current program 578 | + // point. 579 | + bool IsInstructionCurrentlyLive(Item* instruction) const { 580 | + // If the instruction has not started yet, it is not alive. 581 | + if (!IsPlaced(instruction->instruction)) { 582 | + return false; 583 | + } 584 | + for (const HloInstruction* user : instruction->instruction->users()) { 585 | + if (!IsPlaced(user)) { 586 | + // If there is an unplaced user, consider this instruction currently 587 | + // live. 588 | + return true; 589 | + } 590 | + } 591 | + return false; 592 | + } 593 | + 594 | // Create a new buffer, add it to buffers_, and return a reference. 
595 | - Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
596 | - bool live_out, bool has_indirect_uses) {
597 | + Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
598 | + ItemList&& users, bool live_out, bool has_indirect_uses) {
599 | int buffer_id = buffers_.size();
600 | - buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
601 | - has_indirect_uses, users,
602 | - static_cast<int64>(users.size())});
603 | + buffers_.push_back(Buffer{
604 | + buffer_id, defining_instruction, size_function_(shape), shape, live_out,
605 | + has_indirect_uses, users, static_cast<int64>(users.size())});
606 | return buffers_.back();
607 | }
608 |
609 | @@ -498,6 +602,16 @@ class MemoryUsageTracker {
610 | // (BeginInstruction/EndInstruction calls).
611 | const InstructionList& instruction_list_;
612 |
613 | + // Size function returns the bytes of a given buffer.
614 | + const HloRematerialization::ShapeSizeFunction& size_function_;
615 | +
616 | + // Converts a shape into compact form, returns the same shape if a shape is
617 | + // already considered compact.
618 | + const HloRematerialization::CompactShapeFunction& compact_shape_function_;
619 | +
620 | + // A map that caches the known compact shape for each instruction.
621 | + absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_;
622 | +
623 | // Memory usage at the currently placed instruction.
624 | int64 memory_usage_ = 0;
625 |
626 | @@ -512,9 +626,13 @@ class MemoryUsageTracker {
627 | MemoryUsageTracker::MemoryUsageTracker(
628 | const HloComputation* computation,
629 | const HloRematerialization::ShapeSizeFunction& size_function,
630 | + const HloRematerialization::CompactShapeFunction& compact_shape_function,
631 | const TuplePointsToAnalysis& points_to_analysis,
632 | const InstructionList& instruction_list)
633 | - : computation_(computation), instruction_list_(instruction_list) {
634 | + : computation_(computation),
635 | + instruction_list_(instruction_list),
636 | + size_function_(size_function),
637 | + compact_shape_function_(compact_shape_function) {
638 | PointsToSet::BufferSet live_out_set =
639 | points_to_analysis.GetPointsToSet(computation_->root_instruction())
640 | .CreateFlattenedSet();
641 | @@ -556,7 +674,7 @@ MemoryUsageTracker::MemoryUsageTracker(
642 | }
643 | } else {
644 | buffer = &CreateBufferFromLogicalBuffer(
645 | - logical_buffer, points_to_analysis, size_function,
646 | + logical_buffer, points_to_analysis,
647 | ContainsKey(live_out_set, logical_buffer));
648 | item->buffers_defined.push_back(buffer->id);
649 | for (Item* user : buffer->users) {
650 | @@ -566,6 +684,14 @@ MemoryUsageTracker::MemoryUsageTracker(
651 |
652 | logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
653 | }
654 | +
655 | + // Trace the output of each instruction. This is so that we can properly
656 | + // track which outputs GTEs have.
657 | + for (const LogicalBuffer* logical_buffer :
658 | + points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) {
659 | + item->buffers_output.push_back(
660 | + logical_buffer_to_buffer_id[logical_buffer]);
661 | + }
662 | }
663 | XLA_VLOG_LINES(10, ToString());
664 | DCHECK(Check());
665 | @@ -609,9 +735,9 @@ Status MemoryUsageTracker::EndInstruction() {
666 | << buffer.ToString() << " has negative unfinished use count.";
667 | if (buffer.unfinished_user_count == 0) {
668 | // Buffer is now dead.
669 | - VLOG(3) << " " << buffer.ToString() << " is now dead."; 670 | memory_usage_ -= AllocatedSize(buffer_id); 671 | - CHECK_GE(memory_usage_, 0); 672 | + // The memory usage can become negative inside the computation as we can 673 | + // free up the parameter space and reuse it for other tensors. 674 | } 675 | } 676 | 677 | @@ -620,9 +746,9 @@ Status MemoryUsageTracker::EndInstruction() { 678 | for (BufferId buffer_id : in_progress_item_->buffers_defined) { 679 | const Buffer& buffer = buffers_.at(buffer_id); 680 | if (buffer.unfinished_user_count == 0) { 681 | - VLOG(3) << " " << buffer.ToString() << " is immediately dead."; 682 | memory_usage_ -= AllocatedSize(buffer_id); 683 | - CHECK_GE(memory_usage_, 0); 684 | + // The memory usage can become negative inside the computation as we can 685 | + // free up the parameter space and reuse it for other tensors. 686 | } 687 | } 688 | 689 | @@ -637,6 +763,30 @@ Status MemoryUsageTracker::EndInstruction() { 690 | return Status::OK(); 691 | } 692 | 693 | +int64 MemoryUsageTracker::MemoryReducedIfCompressed( 694 | + Item* item, const Shape& compact_shape) const { 695 | + CHECK_NE(in_progress_item_, nullptr); 696 | + if (!item->placed || item == in_progress_item_) { 697 | + return 0; 698 | + } 699 | + 700 | + int64 memory_reduced = 0; 701 | + 702 | + // We only compress a single piece of an output at one time. 703 | + CHECK_EQ(item->buffers_output.size(), 1); 704 | + BufferId buffer_id = item->buffers_output[0]; 705 | + if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id) && 706 | + IsInstructionCurrentlyLive(item)) { 707 | + const Buffer& buffer = buffers_.at(buffer_id); 708 | + memory_reduced += buffer.size; 709 | + 710 | + int64 compact_shape_size = size_function_(compact_shape); 711 | + // Account for buffers that are compressed after instruction. 712 | + memory_reduced -= compact_shape_size; 713 | + } 714 | + return memory_reduced; 715 | +} 716 | + 717 | int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const { 718 | CHECK_NE(in_progress_item_, nullptr); 719 | if (!item->placed || item == in_progress_item_) { 720 | @@ -736,6 +886,56 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized( 721 | return memory_reduced; 722 | } 723 | 724 | +Status MemoryUsageTracker::AddCompressInstructions(Item* original_item, 725 | + Item* compressed_item, 726 | + Item* uncompressed_item) { 727 | + // Original buffer is now dead. 728 | + memory_usage_ -= size_function_(original_item->instruction->shape()); 729 | + // Compressed buffer is now alive. 
730 | + memory_usage_ += size_function_(compressed_item->instruction->shape());
731 | +
732 | + ItemList placed_users;
733 | + ItemList unplaced_users;
734 | + CHECK_EQ(original_item->buffers_output.size(), 1);
735 | + BufferId original_buffer_id = original_item->buffers_output[0];
736 | + Buffer& original_buffer = buffers_.at(original_buffer_id);
737 | + for (Item* user : original_buffer.users) {
738 | + if (user->placed) {
739 | + CHECK(IsFinished(user)) << user->instruction->name();
740 | + placed_users.push_back(user);
741 | + } else {
742 | + unplaced_users.push_back(user);
743 | + }
744 | + }
745 | + original_buffer.users = std::move(placed_users);
746 | + original_buffer.unfinished_user_count = 0;
747 | + original_buffer.users.push_back(compressed_item);
748 | + Buffer& compressed_buffer =
749 | + NewBuffer(compressed_item, compressed_item->instruction->shape(),
750 | + {uncompressed_item}, /*live_out=*/false,
751 | + /*has_indirect_uses=*/false);
752 | + compressed_item->buffers_used = original_item->buffers_output;
753 | + compressed_item->buffers_output = {compressed_buffer.id};
754 | + compressed_item->buffers_defined.push_back(compressed_buffer.id);
755 | +
756 | + Buffer& uncompressed_buffer =
757 | + NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
758 | + std::move(unplaced_users), /*live_out=*/false,
759 | + /*has_indirect_uses=*/false);
760 | +
761 | + uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
762 | + uncompressed_item->buffers_output = {uncompressed_buffer.id};
763 | + uncompressed_item->buffers_defined = {uncompressed_buffer.id};
764 | +
765 | + for (Item* user : uncompressed_buffer.users) {
766 | + BufferIdList& buffers_used = user->buffers_used;
767 | + std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
768 | + uncompressed_buffer.id);
769 | + }
770 | +
771 | + return Status::OK();
772 | +}
773 | +
774 | Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
775 | Item* remat_item) {
776 | VLOG(3) << "AddRematerializedInstruction: original_instruction = "
777 | @@ -831,6 +1031,17 @@ string MemoryUsageTracker::ToString() const {
778 | return output;
779 | }
780 |
781 | +StatusOr<Shape> MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) {
782 | + auto it = compact_shape_.find(hlo);
783 | + if (it != compact_shape_.end()) {
784 | + return it->second;
785 | + }
786 | + const Shape& original_shape = hlo->shape();
787 | + TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape));
788 | + compact_shape_[hlo] = min_shape;
789 | + return min_shape;
790 | +}
791 | +
792 | bool MemoryUsageTracker::Check() const {
793 | auto elements_are_unique = [](const BufferIdList& vec) {
794 | return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
795 | @@ -917,12 +1128,16 @@ int64 RematerializationCost(const HloInstruction* instruction,
796 | // candidate which reduce memory use at the program point of the current
797 | // instruction as indicated by memory_tracker. nullptr is returned if no
798 | // candidate can be found.
799 | -Item* PickRematerializationCandidate(
800 | - const MemoryUsageTracker& memory_tracker,
801 | +std::pair<Item*, RematStrategy>
802 | +MemoryUsageTracker::PickRematerializationCandidate(
803 | + const RematerializationAlg algorithm,
804 | const InstructionList& instruction_list, int64 memory_limit_bytes,
805 | absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
806 | Item* best_item = nullptr;
807 | int64 best_cost = 0;
808 | + RematStrategy best_strategy;
809 | +
810 | + VLOG(5) << "Picking candidate";
811 |
812 | // TODO(b/35244891): This is currently quadratic in the number of HLO
813 | // instructions.
814 | @@ -947,68 +1162,520 @@ Item* PickRematerializationCandidate(
815 | if (!CanBeRematerialized(candidate, remat_able)) {
816 | VLOG(5) << "candidate " << candidate->name()
817 | << " not viable: is not rematerializable";
818 | +
819 | continue;
820 | }
821 |
822 | - // If any of the candidate's control successor has been placed, we need to
823 | - // skip this candidate. Otherwise we will violate control dependency.
824 | - bool control_successor_placed =
825 | - std::any_of(candidate->control_successors().begin(),
826 | - candidate->control_successors().end(),
827 | - [&memory_tracker](const HloInstruction* inst) {
828 | - return memory_tracker.IsPlaced(inst);
829 | - });
830 | + if (item->buffers_output.size() == 1 &&
831 | + (algorithm == RematerializationAlg::kCompressAlg ||
832 | + algorithm == RematerializationAlg::kStandardAndCompressAlg)) {
833 | + // Only consider compressing single output instruction.
834 | + const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
835 | +
836 | + if (item->placed && item != in_progress_item_ &&
837 | + !output_buffer.live_out) {
838 | + const Shape& original_shape = item->instruction->shape();
839 | + if (original_shape.IsArray()) {
840 | + Shape compact_shape = GetCompactShape(item->instruction).ValueOrDie();
841 | + const int64 memory_reduced =
842 | + MemoryReducedIfCompressed(item, compact_shape);
843 | + if (memory_reduced > 0) {
844 | + const int64 cost = memory_limit_bytes / memory_reduced;
845 | + if (best_item == nullptr || cost < best_cost) {
846 | + VLOG(3) << "candidate " << candidate->name() << "("
847 | + << candidate->ToShortString() << ")"
848 | + << " now best when compressed into "
849 | + << compact_shape.ToString(true);
850 | + RematStrategy strategy;
851 | + strategy.kind = RematStrategy::kCompress;
852 | + best_strategy = strategy;
853 | + best_strategy.compact_shape = compact_shape;
854 | + best_item = item;
855 | + best_cost = cost;
856 | + }
857 | + }
858 | + }
859 | + }
860 | + }
861 | +
862 | + // If any of the candidate's control successors has been placed, we need
863 | + // to skip this candidate. Otherwise we will violate control dependency.
864 | + bool control_successor_placed = std::any_of(
865 | + candidate->control_successors().begin(),
866 | + candidate->control_successors().end(),
867 | + [this](const HloInstruction* inst) { return IsPlaced(inst); });
868 |
869 | if (control_successor_placed) {
870 | continue;
871 | }
872 |
873 | - const int64 memory_reduced =
874 | - memory_tracker.MemoryReducedIfRematerialized(item);
875 | + if (algorithm == RematerializationAlg::kStandardAlg ||
876 | + algorithm == RematerializationAlg::kStandardAndCompressAlg) {
877 | + const int64 memory_reduced = MemoryReducedIfRematerialized(item);
878 |
879 | - if (memory_reduced <= 0) {
880 | - VLOG(5) << "candidate " << candidate->name()
881 | - << " memory reduced = " << memory_reduced << " <= 0";
882 | - continue;
883 | + if (memory_reduced > 0) {
884 | + const int cost =
885 | + RematerializationCost(candidate, memory_reduced, memory_limit_bytes);
886 | +
887 | + VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
888 | + << memory_reduced << ", cost per byte " << cost;
889 | +
890 | + if (best_item == nullptr || cost < best_cost) {
891 | + VLOG(5) << "candidate " << candidate->name() << " now best";
892 | + best_strategy.kind = RematStrategy::kRecompute;
893 | + best_item = item;
894 | + best_cost = cost;
895 | + }
896 | + }
897 | + }
898 | + }
899 | + return {best_item, best_strategy};
900 | +}
901 | +
902 | +StatusOr<bool> DerematerializeInstruction(HloComputation* computation,
903 | + HloInstruction* source_node) {
904 | +
905 | + for (auto inst : computation->instructions()) {
906 | + if (inst->name().find(source_node->name() + ".remat") == 0) {
907 | + std::vector<HloInstruction*> users = inst->users();
908 | + for (HloInstruction* user : users) {
909 | + TF_RETURN_IF_ERROR(inst->ReplaceUseWith(user, source_node));
910 | + }
911 | + }
912 | + }
913 | + return true;
914 | +}
915 | +
916 | +// Rematerialize the instruction source_node and change its use in target_user:
917 | +// before remat:
918 | +// ---> target_user
919 | +// /
920 | +// source_node -------|
921 | +// \
922 | +// ----> other users
923 | +//
924 | +// after remat:
925 | +//
926 | +// remat_copy ------> target_user
927 | +//
928 | +// source_node -----> other users
929 | +//
930 | +StatusOr<bool> RematerializeInstructionPath(
931 | + HloComputation* computation, Item* source_node, Item* target_user,
932 | + absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
933 | + InstructionList* instruction_list, int path_size,
934 | + std::vector &articulations_vector) {
935 | +
936 | + HloInstruction* source_node_inst = source_node->instruction;
937 | +
938 | + if (!IsRematerializable(source_node_inst) ||
939 | + path_size == 0 ||
940 | + source_node->blacklisted) {
941 | + return false;
942 | + }
943 | +
944 | + HloInstruction* remat_copy_inst =
945 | + computation->AddInstruction(source_node_inst->Clone("remat"));
946 | +
947 | + Item* remat_copy_item = instruction_list->CreateItem(remat_copy_inst);
948 | +
949 | + TF_RETURN_IF_ERROR(source_node_inst->ReplaceUseWith(
950 | + target_user->instruction, remat_copy_inst));
951 | +
952 | + ItemList place_before;
953 | + place_before.push_back(source_node);
954 | +
955 | + instruction_list->InsertAfterInstructions(remat_copy_item, place_before);
956 | + remat_copy_item->placed = true;
957 | +
958 | + if (source_node_inst->users().empty()) {
959 | + if (ContainsKey(*remat_move_instructions, source_node_inst)) {
960 | + remat_copy_item->blacklisted = true;
961 | + }
962 | + remat_move_instructions->insert(remat_copy_inst);
963 | + }
964 | +
965 | + auto* inst_item = instruction_list->first();
966 | + for (; inst_item != nullptr; inst_item = instruction_list->next(inst_item)) {
967 | + for (auto inst_item_use : inst_item->instruction->users()) {
968 | + if (inst_item_use == remat_copy_inst) {
969 | + RematerializeInstructionPath(computation, inst_item, remat_copy_item,
970 | + remat_move_instructions, instruction_list, path_size-1,
971 | + articulations_vector);
972 | + }
973 | + }
974 | + }
975 | +
976 | + return true;
977 | +}
978 | +
979 | +StatusOr<int64> RematerializeInstruction(
980 | + MemoryUsageTracker* memory_tracker, Item* best_item,
981 | + absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
982 | + InstructionList* instruction_list) {
983 | + HloInstruction* best = best_item->instruction;
984 | + VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
985 | + << HumanReadableNumBytes(
986 | + memory_tracker->MemoryReducedIfRematerialized(best_item))
987 | + << ")";
988 | +
989 | + int64 net_instructions_added = 0;
990 | +
991 | + HloComputation* computation = best->parent();
992 | +
993 | + HloInstruction* remat =
994 | + computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
995 | +
996 | + // Add control dependencies to the new operation.
997 | + for (auto successor : best->control_successors()) {
998 | + TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
999 | + }
1000 | + for (auto predecessor : best->control_predecessors()) {
1001 | + TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
1002 | + }
1003 | +
1004 | + Item* remat_item = instruction_list->CreateItem(remat);
1005 | +
1006 | + // Replace each remaining use of 'best' with the rematerialization.
1007 | + std::vector<HloInstruction*> best_users_copy = best->users();
1008 | + for (HloInstruction* user : best_users_copy) {
1009 | + if (!memory_tracker->IsPlaced(user)) {
1010 | + VLOG(2) << " Replacing use of " << best->name() << " in " << user->name()
1011 | + << " with " << remat->name();
1012 | + TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
1013 | + }
1014 | + }
1015 | +
1016 | + // Account for the rematerialization in the memory tracker.
1017 | + TF_RETURN_IF_ERROR(
1018 | + memory_tracker->AddRematerializedInstruction(best_item, remat_item));
1019 | +
1020 | + // Insert rematerialized instruction right before the earliest unplaced
1021 | + // use of the instruction *and* the earliest unplaced last use of any
1022 | + // operands of remat. Unplaced uses of the remat's operands are included
1023 | + // because we don't want to extend the live range of remat's operands as
1024 | + // this could increase memory usage.
1025 | + ItemList place_before;
1026 | + for (auto user : remat->users()) {
1027 | + place_before.push_back(instruction_list->GetItem(user));
1028 | + }
1029 | + for (auto* operand : remat->operands()) {
1030 | + for (auto* operand_user : operand->users()) {
1031 | + if (operand_user != remat) {
1032 | + Item* operand_user_item = instruction_list->GetItem(operand_user);
1033 | + if (!operand_user_item->placed) {
1034 | + place_before.push_back(operand_user_item);
1035 | + }
1036 | + }
1037 | }
1038 | + }
1039 | + // Insert rematerialized instruction before any of its successors to
1040 | + // preserve ordering regarding control dependency.
1041 | + for (auto successor : remat->control_successors()) {
1042 | + Item* successor_item = instruction_list->GetItem(successor);
1043 | + // Assert to make sure we never remat an operation with control
1044 | + // successor already placed.
1045 | + CHECK(!successor_item->placed) << successor_item->instruction->name();
1046 | + place_before.push_back(successor_item);
1047 | + }
1048 | + instruction_list->InsertBeforeInstructions(remat_item, place_before);
1049 | +
1050 | + // If the rematerialized instruction is dead then rematerialization is
1051 | + // essentially a move. Don't delete the instruction now because we don't
1052 | + // want duplicate HloInstruction* values during the course of the
1053 | + // transformation because we keep maps with HloInstruction* values as
1054 | + // keys.
1055 | + if (best->users().empty()) {
1056 | + VLOG(2) << best->name() << " is now dead";
1057 | + if (ContainsKey(*remat_move_instructions, best)) {
1058 | + // Previously, 'best' was a rematerialization which killed the
1059 | + // instruction it was a copy of. Now 'remat' is a rematerialization
1060 | + // of 'best' and kills 'best'. Stop rematerializing this instruction
1061 | + // to avoid an infinite loop.
1062 | + instruction_list->Blacklist(remat);
1063 | + }
1064 | + remat_move_instructions->insert(remat);
1065 | +
1066 | + } else {
1067 | + net_instructions_added++;
1068 | + }
1069 | + return net_instructions_added;
1070 | +}
1071 |
1072 | - const int cost = RematerializationCost(candidate, memory_tracker,
1073 | - memory_reduced, memory_limit_bytes);
1074 | +StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
1075 | + Item* best_item, const Shape& compact_shape,
1076 | + InstructionList* instruction_list) {
1077 | + HloInstruction* best = best_item->instruction;
1078 | + VLOG(5) << "Transposing instruction " << best->name() << " (saving "
1079 | + << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
1080 | + best_item, compact_shape))
1081 | + << ") to" << compact_shape.ToString(true);
1082 |
1083 | - VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
1084 | - << memory_reduced << ", cost per byte " << cost;
1085 | + HloComputation* computation = best->parent();
1086 |
1087 | - if (best_item == nullptr || cost < best_cost) {
1088 | - VLOG(5) << "candidate " << candidate->name() << " now best";
1089 | - best_item = item;
1090 | - best_cost = cost;
1091 | + HloInstruction* compressed = computation->AddInstruction(
1092 | + HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best));
1093 | +
1094 | + HloInstruction* uncompressed = computation->AddInstruction(
1095 | + HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed));
1096 | +
1097 | + Item* compressed_item = instruction_list->CreateItem(compressed);
1098 | + compressed_item->placed = true;
1099 | +
1100 | + Item* uncompressed_item = instruction_list->CreateItem(uncompressed);
1101 | +
1102 | + // Replace each remaining use of 'best' with the uncompressed.
1103 | + std::vector<HloInstruction*> best_users_copy = best->users();
1104 | + for (HloInstruction* user : best_users_copy) {
1105 | + if (!memory_tracker->IsPlaced(user)) {
1106 | + VLOG(5) << " Replacing use of " << best->name() << " in " << user->name()
1107 | + << " with " << uncompressed->name();
1108 | + TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed));
1109 | }
1110 | }
1111 | - return best_item;
1112 | +
1113 | + // Account for the rematerialization in the memory tracker.
1114 | + TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions(
1115 | + best_item, compressed_item, uncompressed_item));
1116 | +
1117 | + // Insert rematerialized instruction right before the earliest unplaced
1118 | + // use of the instruction.
1119 | +  ItemList place_before;
1120 | +  for (auto user : uncompressed->users()) {
1121 | +    place_before.push_back(instruction_list->GetItem(user));
1122 | +  }
1123 | +
1124 | +  instruction_list->Blacklist(compressed_item->instruction);
1125 | +  instruction_list->Blacklist(uncompressed_item->instruction);
1126 | +
1127 | +  instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);
1128 | +
1129 | +  instruction_list->InsertAfterInstructions(compressed_item, {best_item});
1130 | +
1131 | +  return 2;
1132 |   }
1133 |
1134 |   }  // namespace
1135 |
1136 | -StatusOr<int64> HloRematerialization::ComputePeakMemory(
1137 | +static int64 ByteSizeOf(const Shape& shape) {
1138 | +  return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
1139 | +}
1140 | +
1141 | +Status HloRematerialization::DumpScheduleDotGraph(const HloComputation* computation,
1142 | +    const HloInstructionSequence& order, std::ofstream& dotfile, int64 peak) {
1143 | +  InstructionList instruction_list(order);
1144 | +  MemoryUsageTracker tracker(computation, size_function_,
1145 | +                             compact_shape_function_, *points_to_analysis_,
1146 | +                             instruction_list);
1147 | +
1148 | +  dotfile << "\t\tlabel=\"" << computation->name() << "\"\n";
1149 | +  dotfile << "\t\tnode[shape=box style=filled fontsize=8 fillcolor=\"0.0 0.0 1.0\"];\n";
1150 | +
1151 | +  dotfile << "\t\t{ rank=same ";
1152 | +  std::set<const HloInstruction*> inst_toprint;
1153 | +  for (auto* item = instruction_list.first(); item != nullptr;
1154 | +       item = instruction_list.next(item)) {
1155 | +    const HloInstruction* instruction = item->instruction;
1156 | +    if (ByteSizeOf(instruction->shape()) > 512*1024) {
1157 | +      string inst_name = instruction->name();
1158 | +      dotfile << "\"" << inst_name << "\" ";
1159 | +      inst_toprint.insert(instruction);
1160 | +      std::vector<HloInstruction*> users = instruction->users();
1161 | +      for (HloInstruction* user : users) {
1162 | +        string user_name = user->name();
1163 | +        inst_toprint.insert(user);
1164 | +        dotfile << "\"" << user_name << "\" ";
1165 | +      }
1166 | +    }
1167 | +  }
1168 | +  dotfile << "}";
1169 | +
1170 | +  for (auto* item = instruction_list.first(); item != nullptr;
1171 | +       item = instruction_list.next(item)) {
1172 | +
1173 | +    const HloInstruction* instruction = item->instruction;
1174 | +    string inst_name = instruction->name();
1175 | +
1176 | +    TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
1177 | +    item->placed = true;
1178 | +
1179 | +    TF_ASSIGN_OR_RETURN(int64 callee_usage,
1180 | +                        CalledComputationsMemoryUsage(instruction));
1181 | +
1182 | +    int64 node_mem = tracker.memory_usage() + callee_usage;
1183 | +    double ratio = node_mem/(double)peak;
1184 | +
1185 | +    if (inst_toprint.count(instruction) != 0) {
1186 | +      if (inst_name.find("remat") != std::string::npos) {
1187 | +        dotfile << "\t\t\"" << inst_name << "\" [color=\"blue\", penwidth=2, fillcolor=\"0.0 " << std::tanh(ratio)*1.2 << " 1.0\"];\n";
1188 | +      } else {
1189 | +        dotfile << "\t\t\"" << inst_name << "\" [fillcolor=\"0.0 " << std::tanh(ratio)*1.2 << " 1.0\"];\n";
1190 | +      }
1191 | +    }
1192 | +
1193 | +    // Print node in dot file
1194 | +
1195 | +    if (ByteSizeOf(instruction->shape()) > 512*1024) {
1196 | +      std::vector<HloInstruction*> users = instruction->users();
1197 | +      for (HloInstruction* user : users) {
1198 | +        string user_name = user->name();
1199 | +
1200 | +        dotfile << "\t\t\"" << inst_name << "\" -> \"" << user_name << "\" [penwidth = " << 100*ByteSizeOf(instruction->shape())/(double)peak << "];\n";
1201 | +      }
1202 | +    }
1203 | +
1204 | +    //tracker.memory_usage() + callee_usage;
1205 | +
1206 | +    TF_RETURN_IF_ERROR(tracker.EndInstruction());
1207 | +  }
1208 | +  return Status::OK();
1209 | +}
1210 | +
1211 | +Status HloRematerialization::DumpModuleScheduleDotGraph(string prefix, HloModule* module, int64 peak) {
1212 | +  std::ofstream dotfile(prefix+"."+module->name()+".dot");
1213 | +
1214 | +  dotfile << "digraph {\n";
1215 | +  module->clear_schedule();
1216 | +
1217 | +  HloDCE().Run(module);
1218 | +
1219 | +  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
1220 | +
1221 | +  HloMemoryScheduler scheduler(
1222 | +      [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
1223 | +      ComputationSchedulerToModuleScheduler(
1224 | +          DefaultMemoryScheduler
1225 | +      ));
1226 | +  scheduler.Run(module);
1227 | +
1228 | +  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
1229 | +
1230 | +  call_graph_ = CallGraph::Build(module);
1231 | +
1232 | +  TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
1233 | +      [this, module, &dotfile, peak](const CallGraphNode& node) -> Status {
1234 | +        static int i = 0;
1235 | +        if (node.context() == CallContext::kSequential) {
1236 | +          dotfile << "\tsubgraph cluster_" << ++i << " {\n";
1237 | +          DumpScheduleDotGraph(node.computation(),
1238 | +              module->schedule().sequence(node.computation()), dotfile, peak);
1239 | +          dotfile << "\t}\n";
1240 | +        }
1241 | +        return Status::OK();
1242 | +      },
1243 | +      /*visit_unreachable_nodes=*/false));
1244 | +
1245 | +  dotfile << "}\n";
1246 | +
1247 | +  dotfile.close();
1248 | +
1249 | +  return Status::OK();
1250 | +}
1251 | +
1252 | +using ComputationPeak = HloRematerialization::ComputationPeak;
1253 | +
1254 | +StatusOr<int64>
1255 | +HloRematerialization::ComputeModulePeakMemory(HloModule* module, bool log=false) {
1256 | +  module->clear_schedule();
1257 | +  HloDCE().Run(module);
1258 | +
1259 | +  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
1260 | +
1261 | +  std::ofstream* logfile = nullptr;
1262 | +
1263 | +  if (log) {
1264 | +    logfile = new std::ofstream("mem."+module->name()+".log");
1265 | +  }
1266 | +
1267 | +  HloMemoryScheduler scheduler(
1268 | +      [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
1269 | +      ComputationSchedulerToModuleScheduler(
1270 | +          DefaultMemoryScheduler
1271 | +      ));
1272 | +  scheduler.Run(module);
1273 | +
1274 | +  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
1275 | +
1276 | +  TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
1277 | +      [this, module, &logfile](const CallGraphNode& node) -> Status {
1278 | +        if (node.context() == CallContext::kSequential) {
1279 | +          if (logfile)
1280 | +            *logfile << "Remating computation: "
1281 | +                     << node.computation()->name() << "\n";
1282 | +
1283 | +          TF_ASSIGN_OR_RETURN(
1284 | +              ComputationPeak peak,
1285 | +              ComputePeakMemory(node.computation(),
1286 | +                                module->schedule().sequence(node.computation()),
1287 | +                                logfile));
1288 | +          computation_peak_memory_[node.computation()] = peak.memory;
1289 | +        }
1290 | +        return Status::OK();
1291 | +      },
1292 | +      /*visit_unreachable_nodes=*/false));
1293 | +
1294 | +  if (logfile) {
1295 | +    logfile->close();
1296 | +    delete logfile;
1297 | +  }
1298 | +
1299 | +  return computation_peak_memory_.at(module->entry_computation());
1300 | +}
1301 | +
1302 | +StatusOr<ComputationPeak> HloRematerialization::ComputePeakMemory(
1303 |      const HloComputation* computation,
1304 | -    const HloInstructionSequence& order) const {
1305 | +    const HloInstructionSequence& order,
1306 | +    std::ofstream* logfile = nullptr) const {
1307 |    InstructionList instruction_list(order);
1308 | -  MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
1309 | +  MemoryUsageTracker tracker(computation, size_function_,
1310 | +                             compact_shape_function_, *points_to_analysis_,
1311 |                              instruction_list);
1312 | -  int64 peak_memory = tracker.memory_usage();
1313 | +  ComputationPeak peak;
1314 | +  peak.memory = tracker.memory_usage();
1315 | +  peak.instruction = instruction_list.first()->instruction;
1316 | +
1317 | +  absl::flat_hash_map<const HloInstruction*, bool> remat_able;
1318 | +
1319 |    for (auto* item = instruction_list.first(); item != nullptr;
1320 |         item = instruction_list.next(item)) {
1321 | +
1322 |      const HloInstruction* instruction = item->instruction;
1323 | +    std::string name = instruction->name();
1324 | +
1325 |      TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
1326 | +    item->placed = true;
1327 | +
1328 |      TF_ASSIGN_OR_RETURN(int64 callee_usage,
1329 |                          CalledComputationsMemoryUsage(instruction));
1330 | -    peak_memory =
1331 | -        std::max(peak_memory, tracker.memory_usage() + callee_usage);
1332 | +
1333 | +    if (logfile) {
1334 | +      *logfile << "  " << name << " mem: " <<
1335 | +          HumanReadableNumBytes(tracker.memory_usage() + callee_usage) <<
1336 | +          " peak: " << HumanReadableNumBytes(peak.memory) << "\n";
1337 | +      Item* best_item;
1338 | +      RematStrategy best_strategy;
1339 | +      std::tie(best_item, best_strategy) =
1340 | +          tracker.PickRematerializationCandidate(
1341 | +              RematerializationAlg::kStandardAlg,
1342 | +              instruction_list, 0, &remat_able);
1343 | +
1344 | +      if (best_item) {
1345 | +        *logfile << "  Largest Alive: " << best_item->instruction->name() << " "
1346 | +                 << " size: " << HumanReadableNumBytes(ShapeUtil::ByteSizeOf(best_item->instruction->shape(), sizeof(void*)))
1347 | +                 << "\n";
1348 | +      }
1349 | +    }
1350 | +
1351 | +    if (tracker.memory_usage() + callee_usage > peak.memory) {
1352 | +      peak.memory = tracker.memory_usage() + callee_usage;
1353 | +      peak.instruction = item->instruction;
1354 | +    }
1355 | +
1356 |      TF_RETURN_IF_ERROR(tracker.EndInstruction());
1357 |    }
1358 |    VLOG(1) << "Peak memory for " << computation->name() << ": "
1359 | -          << HumanReadableNumBytes(peak_memory);
1360 | -  return peak_memory;
1361 | +          << HumanReadableNumBytes(peak.memory);
1362 | +  return peak;
1363 |  }
1364 |
1365 |  StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
1366 | @@ -1026,6 +1693,92 @@ StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
1367 |    return callee_usage;
1368 |  }
1369 |
1370 | +StatusOr<bool> HloRematerialization::RematerializeComputationByPathes(
1371 | +    HloComputation* computation, HloSchedule* schedule,
1372 | +    int64 memory_limit_bytes, HloArticulationAnalysis& articulations) {
1373 | +  bool changed = false;
1374 | +
1375 | +  TF_ASSIGN_OR_RETURN(ComputationPeak peak,
1376 | +      ComputePeakMemory(computation, schedule->sequence(computation)));
1377 | +
1378 | +  auto articulations_vector = articulations.getArticulationsSortedByFlops();
1379 | +
1380 | +  LOG(WARNING) << "Number of articulations: " << articulations_vector.size();
1381 | +
1382 | +  auto* peak_inst = peak.instruction;
1383 | +
1384 | +  absl::flat_hash_set<const HloInstruction*> remat_move_instructions;
1385 | +
1386 | +  const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
1387 | +
1388 | +  InstructionList instruction_list(schedule->sequence(computation));
1389 | +
1390 | +  int max_path_depth = 10;
1391 | +
1392 | +  for (auto* item = instruction_list.first(); item != nullptr;
1393 | +       item = instruction_list.next(item)) {
1394 | +    HloInstruction* instruction = item->instruction;
1395 | +    item->placed = true;
1396 | +    if (instruction == peak_inst) break;
1397 | +  }
1398 | +
1399 | +  auto* item = instruction_list.first();
1400 | +  for (; item != nullptr; item = instruction_list.next(item)) {
1401 | +
1402 | +    HloInstruction* instruction = item->instruction;
1403 | +    if (!articulations.IsArticulation(instruction)) {
1404 | +      continue;
1405 | +    }
1406 | +
1407 | +    item->placed = true;
1408 | +
1409 | +    if (instruction == peak_inst) {
1410 | +      break;
1411 | +    }
1412 | +
1413 | +    bool control_successor_placed = false;
1414 | +    for (auto inst = instruction->control_successors().begin();
1415 | +         inst != instruction->control_successors().end(); inst++) {
1416 | +      if (instruction_list.GetItem(*inst)->placed) {
1417 | +        control_successor_placed = true;
1418 | +      }
1419 | +    }
1420 | +
1421 | +    if (IsRematerializable(instruction) && !item->blacklisted &&
1422 | +        !control_successor_placed) {
1423 | +
1424 | +      std::vector<HloInstruction*> users = instruction->users();
1425 | +      for (HloInstruction* user : users) {
1426 | +        auto* user_item = instruction_list.GetItem(user);
1427 | +        if (!user_item->placed) {
1428 | +          RematerializeInstructionPath(computation, item, user_item,
1429 | +              &remat_move_instructions, &instruction_list, max_path_depth,
1430 | +              articulations_vector);
1431 | +        }
1432 | +      }
1433 | +    }
1434 | +
1435 | +    const CallSite* callsite = call_graph_node.GetCallSite(instruction);
1436 | +    if (callsite != nullptr &&
1437 | +        callsite->context() == CallContext::kSequential) {
1438 | +      for (HloComputation* called_computation :
1439 | +           callsite->called_computations()) {
1440 | +        if (!ContainsKey(rematerialized_computations_, called_computation)) {
1441 | +          TF_ASSIGN_OR_RETURN(
1442 | +              bool subcomputation_changed,
1443 | +              RematerializeComputationByPathes(called_computation, schedule,
1444 | +                                               memory_limit_bytes, articulations));
1445 | +          changed |= subcomputation_changed;
1446 | +        }
1447 | +      }
1448 | +    }
1449 | +  }
1450 | +
1451 | +  rematerialized_computations_.insert(computation);
1452 | +  return changed;
1453 | +}
1454 | +
1455 | +
1456 |  StatusOr<bool> HloRematerialization::RematerializeComputation(
1457 |      HloComputation* computation, HloSchedule* schedule,
1458 |      int64 memory_limit_bytes) {
1459 | @@ -1037,9 +1790,14 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
1460 |
1461 |    InstructionList instruction_list(schedule->sequence(computation));
1462 |    MemoryUsageTracker memory_tracker(computation, size_function_,
1463 | +                                    compact_shape_function_,
1464 |                                      *points_to_analysis_, instruction_list);
1465 |    bool changed = false;
1466 |
1467 | +  TF_ASSIGN_OR_RETURN(ComputationPeak peak,
1468 | +      ComputePeakMemory(computation, schedule->sequence(computation)));
1469 | +  bool has_get_to_peak = false;
1470 | +
1471 |    // If the rematerialization makes the source instruction dead, then the
1472 |    // rematerialization is added to 'remat_move_instructions' (the
1473 |    // rematerialization is essentially a move). If the next rematerialization of
1474 | @@ -1066,7 +1824,8 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
1475 |    // (program point) if memory_usage exceeds the specified limit then
1476 |    // rematerialize HLO instructions until memory_usage is reduced.
1477 |    int64 instruction_index = 0;
1478 | -  for (auto* item = instruction_list.first(); item != nullptr;
1479 | +  for (auto* item = instruction_list.first();
1480 | +       item != nullptr; //instruction_list.next(peak_inst_item);
1481 |         item = instruction_list.next(item)) {
1482 |      const HloInstruction* instruction = item->instruction;
1483 |      TF_ASSIGN_OR_RETURN(int64 callee_usage,
1484 | @@ -1086,8 +1845,11 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
1485 |                          callee_usage)
1486 |               << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
1487 |
1488 | -    Item* best_item = PickRematerializationCandidate(
1489 | -        memory_tracker, instruction_list, memory_limit_bytes, &remat_able);
1490 | +    Item* best_item;
1491 | +    RematStrategy best_strategy;
1492 | +    std::tie(best_item, best_strategy) =
1493 | +        memory_tracker.PickRematerializationCandidate(
1494 | +            remat_alg, instruction_list, memory_limit_bytes, &remat_able);
1495 |
1496 |      if (best_item == nullptr) {
1497 |        VLOG(3) << "Unable to find rematerialization candidate at program "
1498 | @@ -1099,88 +1861,33 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
1499 |      }
1500 |
1501 |      HloInstruction* best = best_item->instruction;
1502 | -    VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
1503 | -            << HumanReadableNumBytes(
1504 | -                   memory_tracker.MemoryReducedIfRematerialized(best_item))
1505 | -            << ")";
1506 |      changed = true;
1507 |      remat_count++;
1508 |
1509 | -    HloInstruction* remat =
1510 | -        computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
1511 | -
1512 | -    // Add control dependencies to the new operation.
1513 | -    for (auto successor : best->control_successors()) {
1514 | -      TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
1515 | -    }
1516 | -    for (auto predecessor : best->control_predecessors()) {
1517 | -      TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
1518 | -    }
1519 | -
1520 | -    Item* remat_item = instruction_list.CreateItem(remat);
1521 | -
1522 | -    // Replace each remaining use of 'best' with the rematerialization.
1523 | -    std::vector<HloInstruction*> best_users_copy = best->users();
1524 | -    for (HloInstruction* user : best_users_copy) {
1525 | -      if (!memory_tracker.IsPlaced(user)) {
1526 | -        VLOG(2) << "  Replacing use of " << best->name() << " in "
1527 | -                << user->name() << " with " << remat->name();
1528 | -        TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
1529 | -      }
1530 | -    }
1531 | -
1532 | -    // Account for the rematerialization in the memory tracker.
1533 | -    TF_RETURN_IF_ERROR(
1534 | -        memory_tracker.AddRematerializedInstruction(best_item, remat_item));
1535 | -
1536 | -    // Insert rematerialized instruction right before the earliest unplaced
1537 | -    // use of the instruction *and* the earliest unplaced last use of any
1538 | -    // operands of remat. Unplaced uses of the remat's operands are included
1539 | -    // because we don't want to extend the live range of remat's operands as
1540 | -    // this could increase memory usage.
1541 | -    ItemList place_before;
1542 | -    for (auto user : remat->users()) {
1543 | -      place_before.push_back(instruction_list.GetItem(user));
1544 | -    }
1545 | -    for (auto* operand : remat->operands()) {
1546 | -      for (auto* operand_user : operand->users()) {
1547 | -        if (operand_user != remat) {
1548 | -          Item* operand_user_item = instruction_list.GetItem(operand_user);
1549 | -          if (!operand_user_item->placed) {
1550 | -            place_before.push_back(operand_user_item);
1551 | -          }
1552 | -        }
1553 | -      }
1554 | -    }
1555 | -    // Insert rematerialized instruction before any of its successors to
1556 | -    // preserve ordering regarding control dependency.
1557 | -    for (auto successor : remat->control_successors()) {
1558 | -      Item* successor_item = instruction_list.GetItem(successor);
1559 | -      // Assert to make sure we never remat an operation with control
1560 | -      // successor already placed.
1561 | -      CHECK(!successor_item->placed) << successor_item->instruction->name();
1562 | -      place_before.push_back(successor_item);
1563 | -    }
1564 | -    instruction_list.InsertBeforeInstructions(remat_item, place_before);
1565 | -
1566 | -    // If the rematerialized instruction is dead then rematerialization is
1567 | -    // essentially a move. Don't delete the instruction now because we don't
1568 | -    // want duplicate HloInstruction* values during the course of the
1569 | -    // transformation because we keep maps with HloInstruction* values as
1570 | -    // keys.
1571 | -    if (best->users().empty()) {
1572 | -      VLOG(2) << best->name() << " is now dead";
1573 | -      if (ContainsKey(remat_move_instructions, best)) {
1574 | -        // Previously, 'best' was a rematerialization which killed the
1575 | -        // instruction it was a copying of. Now 'remat' is a rematerialization
1576 | -        // of 'best' and kills 'best'. Stop rematerializing this instruction
1577 | -        // to avoid an infinite loop.
1578 | -        instruction_list.Blacklist(remat);
1579 | -      }
1580 | -      remat_move_instructions.insert(remat);
1581 | +    int64 added_instruction = 0;
1582 | +    if (best_strategy.kind == RematStrategy::kCompress) {
1583 | +      VLOG(1) << "Compressing instruction " << best->name() << " (saving "
1584 | +              << HumanReadableNumBytes(
1585 | +                     memory_tracker.MemoryReducedIfCompressed(
1586 | +                         best_item, best_strategy.compact_shape))
1587 | +              << ")";
1588 | +
1589 | +      TF_ASSIGN_OR_RETURN(added_instruction,
1590 | +                          CompressInstruction(&memory_tracker, best_item,
1591 | +                                              best_strategy.compact_shape,
1592 | +                                              &instruction_list));
1593 |      } else {
1594 | -      net_instructions_added++;
1595 | +      VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
1596 | +              << HumanReadableNumBytes(
1597 | +                     memory_tracker.MemoryReducedIfRematerialized(best_item))
1598 | +              << ")";
1599 | +
1600 | +      TF_ASSIGN_OR_RETURN(added_instruction,
1601 | +                          RematerializeInstruction(&memory_tracker, best_item,
1602 | +                                                   &remat_move_instructions,
1603 | +                                                   &instruction_list));
1604 |      }
1605 | +    net_instructions_added += added_instruction;
1606 |
1607 |      VLOG(1) << "memory_usage after rematerialization = "
1608 |              << HumanReadableNumBytes(memory_tracker.memory_usage());
1609 | @@ -1223,10 +1930,16 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
1610 |      VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory);
1611 |
1612 |      TF_RETURN_IF_ERROR(memory_tracker.EndInstruction());
1613 | -  }
1614 | +
1615 | +    if (peak.instruction == instruction) {
1616 | +      // Only try to reduce memory usage up to the peak (normally the
1617 | +      // transition between the forward and backward passes). This helps reduce
1618 | +      // the heuristic's running time.
1619 | +      has_get_to_peak = true;
1620 | +    }
1621 | +  }
1622 |
1623 |    // Verify some invariants on the memory tracker.
1624 | -  CHECK_EQ(memory_tracker.memory_usage(), 0);
1625 |    for (auto* instruction : computation->instructions()) {
1626 |      CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name();
1627 |    }
1628 | @@ -1258,6 +1971,54 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
1629 |    return changed;
1630 |  }
1631 |
1632 | +StatusOr<bool> HloArticulationAnalysis::Run(HloModule* module) {
1633 | +  module->entry_computation()->root_instruction()->Accept(costof.get());
1634 | +
1635 | +  // Visit all computations and fill articulations.
1636 | +  HloSchedule saved_schedule = module->schedule();
1637 | +
1638 | +  auto call_graph_ = CallGraph::Build(module);
1639 | +  TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
1640 | +      [this, module, saved_schedule](const CallGraphNode& node) -> Status {
1641 | +        if (node.context() == CallContext::kSequential) {
1642 | +          SearchForArticulationComputation(node.computation(), saved_schedule);
1643 | +        }
1644 | +        return Status::OK();
1645 | +      },
1646 | +      /*visit_unreachable_nodes=*/false));
1647 | +
1648 | +  return articulations_.size() != 0;
1649 | +}
1650 | +
1651 | +void HloArticulationAnalysis::SearchForArticulationComputation(
1652 | +    HloComputation* computation, const HloSchedule& schedule) {
1653 | +  InstructionList instruction_list(schedule.sequence(computation));
1654 | +  for (auto* I = instruction_list.first(); I != nullptr; I = instruction_list.next(I)) {
1655 | +    discovery_time = 0;
1656 | +    DFS(computation, I->instruction);
1657 | +  }
1658 | +}
1659 | +
1660 | +void HloArticulationAnalysis::DFS(HloComputation* computation, HloInstruction* inst) {
1661 | +  visited_.insert(inst);
1662 | +  discovery_[inst] = low_[inst] = ++discovery_time;
1663 | +
1664 | +  unsigned children = 0;
1665 | +
1666 | +  for (HloInstruction* user : inst->users()) {
1667 | +    if (visited_.count(user) == 0) {
1668 | +      children++;
1669 | +      parent_[user] = inst;
1670 | +      DFS(computation, user);
1671 | +      low_[inst] = std::min(low_[inst], low_[user]);
1672 | +      if (((parent_.count(inst) == 0 && children > 1) ||
1673 | +           (parent_.count(inst) != 0 && low_[user] >= low_[inst]))) {
1674 | +        articulations_.insert(inst);
1675 | +      }
1676 | +    }
1677 | +  }
1678 | +}
1679 | +
1680 |  StatusOr<bool> HloRematerialization::Run(HloModule* module) {
1681 |    VLOG(1) << "HloRematerialization() with memory limit of "
1682 |            << HumanReadableNumBytes(memory_limit_bytes_);
1683 | @@ -1272,6 +2033,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
1684 |    TF_RET_CHECK(module->has_schedule());
1685 |    TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
1686 |
1687 | +
1688 |    // Adjust memory limit to account for the output of the entry
1689 |    // computation. This is necessary because the per-computation accounting in
1690 |    // MemoryUsageTracker do not include output as these are typically allocated
1691 | @@ -1281,11 +2043,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
1692 |    module->result_shape(),
1693 |    [&module_output_size, module, this](const Shape& subshape,
1694 |                                        const ShapeIndex& output_index) {
1695 | -        if (!module->input_output_alias_config().OutputHasAlias(output_index)) {
1696 | -          // Only account for non-aliased outputs to avoid double counting a
1697 | -          // parameter buffer twice.
1698 | -          module_output_size += size_function_(subshape);
1699 | -        }
1700 | +        module_output_size += size_function_(subshape);
1701 |        });
1702 |
1703 |    const int64 adjusted_memory_limit_bytes =
1704 | @@ -1301,10 +2059,11 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
1705 |        [this, module](const CallGraphNode& node) -> Status {
1706 |          if (node.context() == CallContext::kSequential) {
1707 |            TF_ASSIGN_OR_RETURN(
1708 | -              computation_peak_memory_[node.computation()],
1709 | -              ComputePeakMemory(node.computation(), module->schedule().sequence(
1710 | -                                    node.computation())));
1711 | -        }
1712 | +              ComputationPeak peak,
1713 | +              ComputePeakMemory(node.computation(),
1714 | +                                module->schedule().sequence(node.computation())));
1715 | +          computation_peak_memory_[node.computation()] = peak.memory;
1716 | +        }
1717 |          return Status::OK();
1718 |        },
1719 |        /*visit_unreachable_nodes=*/false));
1720 | @@ -1314,40 +2073,112 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
1721 |    // peak memory for a computation does not include the output as this is
1722 |    // typically accounted for in the caller.
1723 |    const int64 before_peak_memory =
1724 | -      computation_peak_memory_.at(module->entry_computation()) +
1725 | -      module_output_size;
1726 | +      computation_peak_memory_.at(module->entry_computation()) +
1727 | +      module_output_size;
1728 |    VLOG(1) << "Peak memory usage of module (before): "
1729 | -          << HumanReadableNumBytes(before_peak_memory);
1730 | +          << HumanReadableNumBytes(before_peak_memory);
1731 |
1732 | -  // Subcomputations called by the entry computation will also be
1733 | -  // rematerialized.
1734 | -  TF_ASSIGN_OR_RETURN(
1735 | -      bool changed,
1736 | -      RematerializeComputation(module->entry_computation(), &module->schedule(),
1737 | -                               adjusted_memory_limit_bytes));
1738 |
1739 | -  // Rematerialization can introduce dead code. This occurs if all uses of an
1740 | -  // instruction are replaced with rematerializations of the instruction.
1741 | +  if (module->config().debug_options().xla_rematerialization_dump_dot())
1742 | +    DumpModuleScheduleDotGraph("before", module, before_peak_memory);
1743 |
1744 | -  // Stash away the schedule during copy insertion, to avoid validation failures
1745 | -  // while the module is in flux.
1746 | -  HloSchedule saved_schedule = module->schedule();
1747 | -  module->clear_schedule();
1748 | -  TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module));
1749 | -  changed |= dead_code_removed;
1750 | -
1751 | -  // After DCE, the module sequence may include instructions which no longer
1752 | -  // exist. Update the schedule and restore it.
1753 | -  TF_RETURN_IF_ERROR(saved_schedule.Update());
1754 | -  TF_RETURN_IF_ERROR(module->set_schedule(std::move(saved_schedule)));
1755 | -  VLOG(1) << "Rematerialized " << instructions_rematerialized_
1756 | -          << " instructions in module " << module->name() << "; "
1757 | -          << net_instructions_added_ << " net instructions added";
1758 | -  const int64 current_peak_memory =
1759 | +  int64 current_peak_memory = before_peak_memory;
1760 | +  int64 best_peak_memory = current_peak_memory;
1761 | +
1762 | +  bool changed = false;
1763 | +  if (remat_alg == RematerializationAlg::kPathAlg) {
1764 | +    LOG(WARNING) << "Remating with PATH\n";
1765 | +    HloArticulationAnalysis articulations;
1766 | +    articulations.Run(module);
1767 | +
1768 | +    int64 last_peak_memory = current_peak_memory+1;
1769 | +    while (last_peak_memory > current_peak_memory
1770 | +           && current_peak_memory > memory_limit_bytes_) {
1771 | +      TF_ASSIGN_OR_RETURN(
1772 | +          changed,
1773 | +          RematerializeComputationByPathes(module->entry_computation(),
1774 | +              &module->schedule(), adjusted_memory_limit_bytes, articulations));
1775 | +
1776 | +      last_peak_memory = current_peak_memory;
1777 | +      TF_ASSIGN_OR_RETURN(current_peak_memory, ComputeModulePeakMemory(module));
1778 | +      current_peak_memory += module_output_size;
1779 | +      if (current_peak_memory < best_peak_memory)
1780 | +        best_peak_memory = current_peak_memory;
1781 | +    }
1782 | +
1783 | +    auto articulation_vector = articulations.getArticulationsSortedByFlops();
1784 | +    for (auto i = 0; i < articulation_vector.size() && current_peak_memory < memory_limit_bytes_; i++) {
1785 | +      DerematerializeInstruction(
1786 | +          articulation_vector[i]->parent(), articulation_vector[i]);
1787 | +
1788 | +      // Quickly try to recompute the memory peak after derematerialization.
1789 | +      // This is imprecise and normally gives a higher value than the real one.
1790 | +      // However, we use it as a filter and only compute the real value
1791 | +      // (which is expensive) when this estimate points to a peak above the budget.
1792 | +      TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
1793 | +          [this, module](const CallGraphNode& node) -> Status {
1794 | +            if (node.context() == CallContext::kSequential) {
1795 | +              TF_ASSIGN_OR_RETURN(
1796 | +                  ComputationPeak peak,
1797 | +                  ComputePeakMemory(node.computation(),
1798 | +                                    module->schedule().sequence(node.computation())));
1799 | +              computation_peak_memory_[node.computation()] = peak.memory;
1800 | +            }
1801 | +            return Status::OK();
1802 | +          },
1803 | +          /*visit_unreachable_nodes=*/false));
1804 | +      current_peak_memory =
1805 | +          computation_peak_memory_.at(module->entry_computation()) +
1806 | +          module_output_size;
1807 | +
1808 | +      // If this first estimated peak is higher than our memory budget, then
1809 | +      // we recompute it with higher precision.
1810 | +      if (current_peak_memory >= memory_limit_bytes_) {
1811 | +        TF_ASSIGN_OR_RETURN(current_peak_memory, ComputeModulePeakMemory(module));
1812 | +        current_peak_memory += module_output_size;
1813 | +        best_peak_memory = current_peak_memory;
1814 | +
1815 | +        if (current_peak_memory >= memory_limit_bytes_) {
1816 | +          break;
1817 | +        }
1818 | +      }
1819 | +    }
1820 | +
1821 | +    TF_ASSIGN_OR_RETURN(current_peak_memory, ComputeModulePeakMemory(module));
1822 | +    current_peak_memory += module_output_size;
1823 | +
1824 | +  } else {
1825 | +    // Subcomputations called by the entry computation will also be
1826 | +    // rematerialized.
1827 | +    TF_ASSIGN_OR_RETURN(
1828 | +        changed,
1829 | +        RematerializeComputation(module->entry_computation(),
1830 | +                                 &module->schedule(), adjusted_memory_limit_bytes));
1831 | +
1832 | +    // Stash away the schedule during copy insertion, to avoid validation failures
1833 | +    // while the module is in flux.
1834 | +    HloSchedule saved_schedule = module->schedule();
1835 | +    module->clear_schedule();
1836 | +    TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module));
1837 | +    changed |= dead_code_removed;
1838 | +
1839 | +    // After DCE, the module sequence may include instructions which no longer
1840 | +    // exist. Update the schedule and restore it.
1841 | +    TF_RETURN_IF_ERROR(saved_schedule.Update());
1842 | +    TF_RETURN_IF_ERROR(module->set_schedule(std::move(saved_schedule)));
1843 | +
1844 | +    current_peak_memory =
1845 |          computation_peak_memory_.at(module->entry_computation()) +
1846 |          module_output_size;
1847 | -  VLOG(1) << "Peak memory usage of module now "
1848 | -          << HumanReadableNumBytes(current_peak_memory) << " ("
1849 | +    best_peak_memory = current_peak_memory;
1850 | +  }
1851 | +
1852 | +  if (module->config().debug_options().xla_rematerialization_dump_dot())
1853 | +    DumpModuleScheduleDotGraph("remat", module, before_peak_memory);
1854 | +
1855 | +  LOG(WARNING) << "Peak memory usage of module now "
1856 | +               << HumanReadableNumBytes(best_peak_memory) << " - "
1857 | +               << HumanReadableNumBytes(module_output_size) << " - ("
1858 |                << current_peak_memory << " bytes), was "
1859 |                << HumanReadableNumBytes(before_peak_memory) << " ("
1860 |                << before_peak_memory << " bytes)";
1861 | @@ -1356,12 +2187,15 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
1862 |            << HumanReadableNumBytes(reduced_peak_memory) << " ("
1863 |            << reduced_peak_memory << " bytes)";
1864 |
1865 | +  if (module->config().debug_options().xla_rematerialization_dump_memlog())
1866 | +    ComputeModulePeakMemory(module, true);
1867 | +
1868 |    if (sizes_ != nullptr) {
1869 |      sizes_->before_bytes = before_peak_memory;
1870 |      sizes_->after_bytes = current_peak_memory;
1871 |    }
1872 |
1873 | -  XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
1874 | +  XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString());
1875 |
1876 |    if (current_peak_memory > memory_limit_bytes_) {
1877 |      LOG(WARNING) << absl::StrFormat(
1878 |  diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
1879 |  index 350cf0f..6220237 100644
1880 |  --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
1881 |  +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
1882 | @@ -21,12 +21,68 @@
1883 |  #include "tensorflow/compiler/xla/service/hlo_computation.h"
1884 |  #include "tensorflow/compiler/xla/service/hlo_instruction.h"
1885 |  #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
1886 | +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
1887 |  #include "tensorflow/compiler/xla/service/hlo_module.h"
1888 |  #include "tensorflow/compiler/xla/service/hlo_schedule.h"
1889 |  #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
1890 | +#include "tensorflow/compiler/xla/shape.h"
1891 | +#include "tensorflow/compiler/xla/statusor.h"
1892 |
1893 |  namespace xla {
1894 |
1895 | +class HloArticulationAnalysis : public HloModulePass {
1896 | + public:
1897 | +  explicit HloArticulationAnalysis() : costof(new HloCostAnalysis([](const Shape& shape) {
1898 | +        return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
1899 | +      }))
1900 | +  {}
1901 | +  ~HloArticulationAnalysis() override = default;
1902 | +
1903 | +  absl::string_view name() const override { return "articulation analysis"; }
1904 | +
1905 | +  std::unique_ptr<HloCostAnalysis> costof;
1906 | +
1907 | +  StatusOr<bool> Run(HloModule* module) override;
1908 | +
1909 | +  bool IsArticulation(HloInstruction* inst) {
1910 | +    return articulations_.count(inst) != 0;
1911 | +  }
1912 | +
1913 | +  std::vector<HloInstruction*> getArticulationsSortedByFlops() {
1914 | +    std::vector<HloInstruction*> articulations;
1915 | +    articulations.insert(articulations.begin(),
1916 | +                         articulations_.begin(), articulations_.end());
1917 | +    std::sort(articulations.begin(), articulations.end(),
1918 | +              [this](const HloInstruction* l, const HloInstruction* r) {
1919 | +                auto costa = (float)(costof->flop_count(*l) +
1920 | +                                     costof->transcendental_count(*l) * 10) /
1921 | +                             (costof->bytes_accessed(*l)+2);
1922 | +                auto costb = (float)(costof->flop_count(*r) +
1923 | +                                     costof->transcendental_count(*r) * 10) /
1924 | +                             (costof->bytes_accessed(*r)+2);
1925 | +                return costa > costb;
1926 | +              });
1927 | +    return articulations;
1928 | +  }
1929 | +
1930 | + protected:
1931 | +  std::set<HloInstruction*> articulations_;
1932 | +
1933 | +  void SearchForArticulationComputation(HloComputation*, const HloSchedule&);
1934 | +
1935 | +  absl::flat_hash_set<HloInstruction*> visited_;
1936 | +  absl::flat_hash_map<HloInstruction*, unsigned> discovery_;
1937 | +  absl::flat_hash_map<HloInstruction*, unsigned> low_;
1938 | +  absl::flat_hash_map<HloInstruction*, HloInstruction*> parent_;
1939 | +  unsigned discovery_time;
1940 | +
1941 | +  void DFS(HloComputation*, HloInstruction*);
1942 | +};
1943 | +
1944 | +enum RematerializationAlg {
1945 | +  kStandardAlg, kPathAlg, kCompressAlg, kStandardAndCompressAlg
1946 | +};
1947 | +
1948 |  // HLO pass which rematerializes instructions to reduce peak memory use, where
1949 |  // memory use is defined as the total size of all live HLO instruction
1950 |  // values. Parameters and constants are included in memory use estimates.
1951 | @@ -38,6 +94,14 @@ class HloRematerialization : public HloModulePass {
1952 |   public:
1953 |    using ShapeSizeFunction = std::function<int64(const Shape&)>;
1954 |
1955 | +  // Computation Peak Helper
1956 | +  struct ComputationPeak {
1957 | +    int64 memory;
1958 | +    HloInstruction* instruction;
1959 | +  };
1960 | +
1961 | +  using CompactShapeFunction = std::function<StatusOr<Shape>(const Shape&)>;
1962 | +
1963 |    // Helper struct that communicates the before / after sizes for the
1964 |    // rematerialization process.
1965 |    struct RematerializationSizes {
1966 | @@ -45,26 +109,39 @@ class HloRematerialization : public HloModulePass {
1967 |      int64 after_bytes;
1968 |    };
1969 |
1970 | +  static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; }
1971 | +
1972 |    // Constructor parameters:
1973 |    //
1974 |    // size_function: Function which returns the size in bytes of the top-level
1975 |    //   buffer of the given shape.
1976 |    //
1977 |    // memory_limit_bytes: The threshold number of bytes to reduce memory use to
1978 | -  //   via rematerialization.
1979 | +  //   via rematerialization. Size of aliased outputs should be subtracted
1980 | +  //   from this.
1981 |    //
1982 |    // sizes: Pointer to data structure which records the peak memory usage of
1983 |    //   the HLO module before/after rematerialization. Value are set during
1984 |    //   Run(). Can be nullptr.
1985 | -  HloRematerialization(const ShapeSizeFunction& size_function,
1986 | -                       int64 memory_limit_bytes, RematerializationSizes* sizes)
1987 | +  //
1988 | +  // compact_shape_function: Function which returns the compact form of a
1989 | +  //   shape. If nullptr is provided, a default identity function is used.
1990 | +  explicit HloRematerialization(
1991 | +      const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
1992 | +      RematerializationSizes* sizes,
1993 | +      CompactShapeFunction compact_shape_function = nullptr)
1994 |        : size_function_(size_function),
1995 |          memory_limit_bytes_(memory_limit_bytes),
1996 | -        sizes_(sizes) {}
1997 | -  ~HloRematerialization() {}
1998 | +        sizes_(sizes),
1999 | +        compact_shape_function_(compact_shape_function == nullptr
2000 | +                                    ? DefaultCompactShapeFunction
2001 | +                                    : std::move(compact_shape_function)) {}
2002 | +  ~HloRematerialization() override = default;
2003 |
2004 |    absl::string_view name() const override { return "rematerialization"; }
2005 |
2006 | +  void setAlgorithm(RematerializationAlg a) { remat_alg = a; }
2007 | +
2008 |    // Runs rematerialization on the given module. Returns whether the module was
2009 |    // changed. Requires that the module has a schedule set
2010 |    // (HloModule::has_schedule() is true) before running. Returns whether any
2011 | @@ -74,6 +151,7 @@ class HloRematerialization : public HloModulePass {
2012 |    StatusOr<bool> Run(HloModule* module) override;
2013 |
2014 |   protected:
2015 | +
2016 |    // Rematerializes instructions within the given computation. 'order' is the
2017 |    // order in which the computation's instructions will be emitted in the
2018 |    // backend. Rematerialized instructions will be added to the HLO computation
2019 | @@ -82,12 +160,30 @@ class HloRematerialization : public HloModulePass {
2020 |                                              HloSchedule* schedule,
2021 |                                              int64 memory_limit_bytes);
2022 |
2023 | +  virtual StatusOr<bool> RematerializeComputationByPathes(
2024 | +      HloComputation* computation,
2025 | +      HloSchedule* schedule,
2026 | +      int64 memory_limit_bytes,
2027 | +      HloArticulationAnalysis&);
2028 | +
2029 |    // Computes and returns the peak memory used by the given computation. The
2030 |    // peak memory is the maximum total size of all live HLO instruction values at
2031 |    // any program point. 'order' is the order in which the HLO instructions will
2032 |    // be emitted which is used to determine lifespans of HLO values.
2033 | -  StatusOr<int64> ComputePeakMemory(const HloComputation* computation,
2034 | -                                    const HloInstructionSequence& order) const;
2035 | +  StatusOr<ComputationPeak> ComputePeakMemory(const HloComputation* computation,
2036 | +                                              const HloInstructionSequence& order,
2037 | +                                              std::ofstream*) const;
2038 | +
2039 | +
2040 | +  StatusOr<int64> ComputeModulePeakMemory(HloModule*, bool);
2041 | +
2042 | +
2043 | +  Status DumpScheduleDotGraph(const HloComputation* computation,
2044 | +                              const HloInstructionSequence& order,
2045 | +                              std::ofstream& dotfile, int64 peak);
2046 | +
2047 | +  Status DumpModuleScheduleDotGraph(string, HloModule*, int64);
2048 | +
2049 |
2050 |    // Returns the peak memory usage of the called computations for the given
2051 |    // instruction. Zero is returned if the instruction calls no computations.
2052 | @@ -108,6 +204,10 @@ class HloRematerialization : public HloModulePass {
2053 |    // module before/after rematerialization
2054 |    RematerializationSizes* sizes_;
2055 |
2056 | +  // Converts a shape into compact form, returns the same shape if a shape is
2057 | +  // already considered compact.
2058 | +  const CompactShapeFunction compact_shape_function_;
2059 | +
2060 |    // Call graph of the hlo_module.
2061 |    std::unique_ptr<CallGraph> call_graph_;
2062 |
2063 | @@ -133,6 +233,8 @@ class HloRematerialization : public HloModulePass {
2064 |    // uses of the original instruction and the original instruction is
2065 |    // dead. Hence, no net instructions were added.
2066 |    int64 net_instructions_added_ = 0;
2067 | +
2068 | +  RematerializationAlg remat_alg = kStandardAlg;
2069 |  };
2070 |
2071 |  }  // namespace xla
2072 |
2073 |  diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
2074 |  index f20ff9a..4faa3a0 100644
2075 |  --- a/tensorflow/compiler/xla/xla.proto
2076 |  +++ b/tensorflow/compiler/xla/xla.proto
2077 | @@ -288,7 +288,17 @@ message DebugOptions {
2078 |    // Blacklist for cuDNN convolutions.
2079 |    string xla_gpu_cudnn_conv_blacklist_path = 128;
2080 |
2081 | -  // Next id: 129
2082 | +  // Rematerialization flags.
2083 | +  bool xla_use_hlo_rematerialization = 129;
2084 | +  string xla_rematerialization_mem_limit = 130;
2085 | +  string xla_rematerialization_scheduler = 131;
2086 | +  string xla_rematerialization_algorithm = 132;
2087 | +  int32 xla_rematerialization_small_node_limit = 133;
2088 | +  bool xla_rematerialization_disable_cuda = 134;
2089 | +  bool xla_rematerialization_dump_dot = 135;
2090 | +  bool xla_rematerialization_dump_memlog = 136;
2091 | +
2092 | +  // Next id: 137
2093 |
2094 |    // Extra options to pass to the compilation backend (e.g. LLVM); specific
2095 |    // interpretation of these values is left to the backend.
2096 | --------------------------------------------------------------------------------
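
Editor's note on `HloArticulationAnalysis::DFS` in the patch above: it is an adaptation of the classic Tarjan articulation-point search, run over the HLO def-use graph via `users()`. The standalone sketch below illustrates the textbook version of that algorithm on a plain integer graph; it is an illustrative toy, not the patch's code. In particular, the textbook non-root test compares a child's `low` value against the parent's *discovery time*, whereas the patch's DFS compares it against the parent's `low_` value.

```cpp
// Toy articulation-point DFS (Tarjan). Nodes are ints, edges are
// undirected adjacency lists. In the patch, nodes are HloInstructions
// and edges are traversed through HloInstruction::users().
#include <algorithm>
#include <iostream>
#include <set>
#include <vector>

struct ArticulationFinder {
  const std::vector<std::vector<int>>& adj;
  std::vector<int> disc, low, parent;
  std::vector<bool> visited;
  std::set<int> articulations;
  int time = 0;

  explicit ArticulationFinder(const std::vector<std::vector<int>>& g)
      : adj(g), disc(g.size()), low(g.size()), parent(g.size(), -1),
        visited(g.size(), false) {}

  void Dfs(int u) {
    visited[u] = true;
    disc[u] = low[u] = ++time;
    int children = 0;
    for (int v : adj[u]) {
      if (!visited[v]) {
        ++children;
        parent[v] = u;
        Dfs(v);
        low[u] = std::min(low[u], low[v]);
        // A root with two or more DFS children, or a non-root whose
        // child cannot reach above it, is an articulation point.
        if ((parent[u] == -1 && children > 1) ||
            (parent[u] != -1 && low[v] >= disc[u])) {
          articulations.insert(u);
        }
      } else if (v != parent[u]) {
        low[u] = std::min(low[u], disc[v]);
      }
    }
  }
};

int main() {
  // Chain 0-1-2 feeding a diamond 2-{3,4}-5. Removing node 1 or node 2
  // disconnects the graph, so both are articulation points.
  std::vector<std::vector<int>> g = {
      {1}, {0, 2}, {1, 3, 4}, {2, 5}, {2, 5}, {3, 4}};
  ArticulationFinder finder(g);
  finder.Dfs(0);
  for (int u : finder.articulations) std::cout << u << "\n";  // prints 1 2
  return 0;
}
```

In the patch, the resulting articulation set is then ranked by `getArticulationsSortedByFlops()`, a FLOPs-plus-transcendentals-per-byte score from `HloCostAnalysis`, so that articulations that are cheap to recompute relative to their size are preferred for rematerialization and later derematerialization.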
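
A second sketch, for orientation only: how the patched pass could be invoked programmatically inside the XLA pipeline. `RunPathRemat` and `budget_bytes` are hypothetical names introduced here for illustration; in the patched tree the selection is instead driven by the `DebugOptions` flags added to `xla.proto` above (`xla_use_hlo_rematerialization`, `xla_rematerialization_algorithm`, etc.), and the module must already have a schedule set, as `Run()` requires.

```cpp
// Hypothetical driver for the patched pass; assumes the patched
// hlo_rematerialization.h shown above is in the tree.
#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

Status RunPathRemat(HloModule* module, int64 budget_bytes) {
  // Same per-shape byte accounting the patch uses internally.
  auto size_fn = [](const Shape& shape) {
    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
  };
  HloRematerialization::RematerializationSizes sizes;
  HloRematerialization remat(size_fn, budget_bytes, &sizes);
  // Equivalent of --xla_rematerialization_algorithm=path.
  remat.setAlgorithm(kPathAlg);
  TF_ASSIGN_OR_RETURN(bool changed, remat.Run(module));
  VLOG(1) << "Rematerialization changed module: " << changed;
  return Status::OK();
}

}  // namespace xla
```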